feat: sanitize external identity provider user data (#681)

* feat: sanitize external user data

* remove config option to disable Sanitization: sanitize_external_user_data

* cleanup

---------

Co-authored-by: Christoph Haas <christoph.h@sprinternet.at>
This commit is contained in:
Mykhailo Roit
2026-05-18 23:28:27 +03:00
committed by GitHub
parent ff935a404e
commit 958dcb8fa9
24 changed files with 1545 additions and 50 deletions

View File

@@ -113,4 +113,4 @@ backend:
api_verify_tls: true
api_timeout: 30s
concurrency: 5
debug: false
debug: false

View File

@@ -8,6 +8,23 @@ To enable encryption, set the [`encryption_passphrase`](../configuration/overvie
> :warning: Important: Once encryption is enabled, it cannot be disabled, and the passphrase cannot be changed!
> Only new or updated records will be encrypted; existing data remains in plaintext until its next modified.
## External Identity Provider Data Sanitization
When users authenticate via LDAP, OIDC, or OAuth, WireGuard Portal sanitizes the field values received from the provider before storing them. This protects against several classes of attack that a compromised or misconfigured identity provider could introduce:
- **Unsafe control characters** — Unicode control and format characters, null bytes, and invalid UTF-8 bytes are stripped from external profile fields before they reach the Vue.js UI or email templates.
- **Email header injection** — carriage return and line feed characters in email fields are rejected entirely, and email fields must parse as plain email addresses.
- **Log injection** — unsafe control and format characters are stripped from all external profile fields and from sanitization log context.
- **Denial of service via oversized fields** — field lengths are capped (e.g., 256 runes for identifiers, 254 characters for email addresses).
- **Reserved identifier collision** — reserved user identifiers such as `"all"`, `"new"`, `"id"`, and internal system user identifiers are rejected.
- **Unsafe authorization groups** — OIDC/OAuth group claims are sanitized before group-based checks; groups changed by control/format stripping or truncation are dropped rather than repaired into allowed/admin matches.
Sanitization is always enabled and cannot be disabled.
When sanitization modifies or clears a field value, a `WARN` log entry is emitted with the provider name, provider type, and field name — but never the raw or sanitized value, to avoid leaking sensitive data into logs. This makes it straightforward to detect and investigate potentially malicious or misconfigured providers.
---
## UI and API Access
WireGuard Portal provides a web UI and a REST API for user interaction. It is important to secure these interfaces to prevent unauthorized access and data breaches.
@@ -21,4 +38,4 @@ A detailed explanation is available in the [Reverse Proxy](../getting-started/re
### Secure Authentication
To prevent unauthorized access, WireGuard Portal supports integrating with secure authentication providers such as LDAP, OAuth2, or Passkeys, see [Authentication](./authentication.md) for more details.
When possible, use centralized authentication and enforce multi-factor authentication (MFA) at the provider level for enhanced account security.
For local accounts, administrators should enforce strong password requirements.
For local accounts, administrators should enforce strong password requirements.

1
go.mod
View File

@@ -31,6 +31,7 @@ require (
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlserver v1.6.3
gorm.io/gorm v1.31.1
pgregory.net/rapid v1.2.0
)
require (

2
go.sum
View File

@@ -432,5 +432,7 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=

View File

@@ -149,5 +149,9 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica
AdminInfoAvailable: adminInfoAvailable,
}
if err := userInfo.Sanitize("ldap", l.cfg.ProviderName); err != nil {
return nil, err
}
return userInfo, nil
}

View File

@@ -0,0 +1,98 @@
package auth
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/testutil"
)
// makeLdapAuthenticator creates a minimal LdapAuthenticator for testing ParseUserInfo.
func makeLdapAuthenticator() *LdapAuthenticator {
return &LdapAuthenticator{
cfg: &config.LdapProvider{
ProviderName: "test-ldap",
FieldMap: config.LdapFields{
BaseFields: config.BaseFields{
UserIdentifier: "uid",
Email: "mail",
Firstname: "givenName",
Lastname: "sn",
Phone: "telephoneNumber",
Department: "department",
},
GroupMembership: "", // no group membership check
},
},
}
}
// makeRawLdapMap builds a minimal raw LDAP attribute map for ParseUserInfo.
func makeRawLdapMap(uid, mail, givenName, sn, phone, department string) map[string]any {
return map[string]any{
"uid": uid,
"mail": mail,
"givenName": givenName,
"sn": sn,
"telephoneNumber": phone,
"department": department,
}
}
// Test: firstname contains \x00 → output firstname has no null byte,
// one WARN log entry with field: "firstname".
func TestLdapParseUserInfo_NullByteInFirstname(t *testing.T) {
auth := makeLdapAuthenticator()
raw := makeRawLdapMap("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
info, err := auth.ParseUserInfo(raw)
records := restore()
require.NoError(t, err)
assert.NotContains(t, info.Firstname, "\x00", "firstname should have null byte removed")
assert.Equal(t, "Alice", info.Firstname)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
rec, found := testutil.FindWarnWithField(records, "firstname")
assert.True(t, found, "expected WARN log entry with field=firstname")
if found {
assert.Equal(t, "WARN", rec["level"])
}
}
// Test: all fields clean → no WARN log entries emitted.
func TestLdapParseUserInfo_AllFieldsClean(t *testing.T) {
auth := makeLdapAuthenticator()
raw := makeRawLdapMap("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering")
restore := testutil.CaptureWarnLogs(t)
info, err := auth.ParseUserInfo(raw)
records := restore()
require.NoError(t, err)
assert.Equal(t, domain.UserIdentifier("alice"), info.Identifier)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean")
}
// Test: identifier is "all" → returns ErrInvalidData.
func TestLdapParseUserInfo_IdentifierAll(t *testing.T) {
auth := makeLdapAuthenticator()
raw := makeRawLdapMap("all", "all@example.com", "Alice", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
_, err := auth.ParseUserInfo(raw)
_ = restore()
require.Error(t, err)
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
}

View File

@@ -155,5 +155,5 @@ func (p PlainOauthAuthenticator) GetUserInfo(
// ParseUserInfo parses the user information from the raw data.
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw, "oauth", p.name)
}

View File

@@ -194,5 +194,5 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
// ParseUserInfo parses the user info.
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw, "oidc", o.name)
}

View File

@@ -13,6 +13,8 @@ func parseOauthUserInfo(
mapping config.OauthFields,
adminMapping *config.OauthAdminMapping,
raw map[string]any,
providerType string,
providerName string,
) (*domain.AuthenticatorUserInfo, error) {
var isAdmin bool
var adminInfoAvailable bool
@@ -27,18 +29,6 @@ func parseOauthUserInfo(
}
}
// next try to parse the user's groups
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
adminInfoAvailable = true
re := adminMapping.GetAdminGroupRegex()
for _, group := range userGroups {
if re.MatchString(strings.TrimSpace(group)) {
isAdmin = true
break
}
}
}
userInfo := &domain.AuthenticatorUserInfo{
Identifier: domain.UserIdentifier(internal.MapDefaultString(raw, mapping.UserIdentifier, "")),
Email: internal.MapDefaultString(raw, mapping.Email, ""),
@@ -51,6 +41,24 @@ func parseOauthUserInfo(
AdminInfoAvailable: adminInfoAvailable,
}
if err := userInfo.Sanitize(providerType, providerName); err != nil {
return nil, err
}
// check admin group match after sanitization
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
adminInfoAvailable = true
re := adminMapping.GetAdminGroupRegex()
for _, group := range userInfo.UserGroups {
if re.MatchString(group) {
isAdmin = true
break
}
}
userInfo.IsAdmin = isAdmin
userInfo.AdminInfoAvailable = adminInfoAvailable
}
return userInfo, nil
}

View File

@@ -0,0 +1,148 @@
package auth
import (
"errors"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/testutil"
)
// makeOauthFieldMapping returns a minimal OauthFields mapping for testing.
func makeOauthFieldMapping() config.OauthFields {
return config.OauthFields{
BaseFields: config.BaseFields{
UserIdentifier: "sub",
Email: "email",
Firstname: "given_name",
Lastname: "family_name",
Phone: "phone",
Department: "department",
},
}
}
// makeOauthRaw builds a minimal raw OAuth user info map.
func makeOauthRaw(sub, email, givenName, familyName, phone, department string) map[string]any {
return map[string]any{
"sub": sub,
"email": email,
"given_name": givenName,
"family_name": familyName,
"phone": phone,
"department": department,
}
}
// Test: email containing \r\n → output email is "",
// one WARN log entry with field: "email" and cleared indication.
func TestParseOauthUserInfo_CRLFInEmail(t *testing.T) {
mapping := makeOauthFieldMapping()
adminMapping := &config.OauthAdminMapping{}
raw := makeOauthRaw("user123", "user\r\n@example.com", "Alice", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, "", info.Email, "email should be cleared when it contains CR/LF")
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
rec, found := testutil.FindWarnWithField(records, "email")
assert.True(t, found, "expected WARN log entry with field=email")
if found {
msg, _ := rec["msg"].(string)
assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when email is cleared")
}
}
// Test: two fields modified (email cleared, firstname truncated) →
// two separate WARN log entries.
func TestParseOauthUserInfo_TwoFieldsModified(t *testing.T) {
mapping := makeOauthFieldMapping()
adminMapping := &config.OauthAdminMapping{}
longFirstname := strings.Repeat("A", 200)
raw := makeOauthRaw("user123", "bad\r\nemail@example.com", longFirstname, "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, "", info.Email, "email should be cleared")
assert.Equal(t, 128, len([]rune(info.Firstname)), "firstname should be truncated to 128 runes")
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 2, warnCount, "expected exactly two WARN log entries (one per modified field)")
_, emailFound := testutil.FindWarnWithField(records, "email")
assert.True(t, emailFound, "expected WARN log entry with field=email")
_, firstnameFound := testutil.FindWarnWithField(records, "firstname")
assert.True(t, firstnameFound, "expected WARN log entry with field=firstname")
}
// Test: identifier "all" → returns ErrInvalidData.
func TestParseOauthUserInfo_IdentifierAll(t *testing.T) {
mapping := makeOauthFieldMapping()
adminMapping := &config.OauthAdminMapping{}
raw := makeOauthRaw("all", "all@example.com", "Alice", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
_, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
_ = restore()
require.Error(t, err)
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
}
func TestParseOauthUserInfo_DropsModifiedGroupBeforeAdminMatch(t *testing.T) {
mapping := makeOauthFieldMapping()
mapping.UserGroups = "groups"
adminMapping := &config.OauthAdminMapping{
AdminGroupRegex: "^wgportal-admins$",
}
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
raw["groups"] = []any{"wgportal-\u200badmins"}
restore := testutil.CaptureWarnLogs(t)
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
records := restore()
require.NoError(t, err)
require.NotNil(t, info)
assert.False(t, info.IsAdmin, "sanitization must not repair a modified group into an admin match")
assert.Empty(t, info.UserGroups)
rec, found := testutil.FindWarnWithField(records, "user_group")
assert.True(t, found, "expected WARN log entry with field=user_group")
if found {
assert.Equal(t, "oidc", rec["provider_type"])
}
}
func TestParseOauthUserInfo_AllowsWhitespaceOnlyGroupTrim(t *testing.T) {
mapping := makeOauthFieldMapping()
mapping.UserGroups = "groups"
adminMapping := &config.OauthAdminMapping{
AdminGroupRegex: "^wgportal-admins$",
}
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
raw["groups"] = []any{" wgportal-admins "}
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
require.NoError(t, err)
require.NotNil(t, info)
assert.True(t, info.IsAdmin)
assert.Equal(t, []string{"wgportal-admins"}, info.UserGroups)
}

View File

@@ -43,7 +43,7 @@ func Test_parseOauthUserInfo_no_admin(t *testing.T) {
})
adminMapping := &config.OauthAdminMapping{}
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
assert.NoError(t, err)
assert.False(t, info.IsAdmin)
assert.Equal(t, info.Firstname, "Test User")
@@ -90,7 +90,7 @@ func Test_parseOauthUserInfo_admin_group(t *testing.T) {
AdminGroupRegex: "^wgportal-admins@mydomain.net$",
}
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
assert.NoError(t, err)
assert.True(t, info.IsAdmin)
assert.Equal(t, info.Firstname, "Test User")
@@ -132,7 +132,7 @@ func Test_parseOauthUserInfo_admin_value(t *testing.T) {
})
adminMapping := &config.OauthAdminMapping{} // test with default regex
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
assert.NoError(t, err)
assert.True(t, info.IsAdmin)
assert.Equal(t, info.Firstname, "Test User")
@@ -175,7 +175,7 @@ func Test_parseOauthUserInfo_admin_value_custom(t *testing.T) {
AdminValueRegex: "^1$",
}
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
assert.NoError(t, err)
assert.True(t, info.IsAdmin)
assert.Equal(t, info.Firstname, "Test User")

View File

@@ -0,0 +1,90 @@
package auth
import (
"bytes"
"encoding/json"
"log/slog"
"testing"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/testutil"
)
// captureWarnLogsInline redirects the default slog logger to a buffer, calls fn,
// restores the original logger, and returns the captured log records.
func captureWarnLogsInline(fn func()) []map[string]any {
original := slog.Default()
var buf bytes.Buffer
handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})
slog.SetDefault(slog.New(handler))
fn()
slog.SetDefault(original)
var records []map[string]any
decoder := json.NewDecoder(&buf)
for decoder.More() {
var rec map[string]any
if err := decoder.Decode(&rec); err == nil {
records = append(records, rec)
}
}
return records
}
// Property 7: Sanitization change logging completeness
func TestPropertySanitizationChangeLoggingCompleteness(t *testing.T) {
mapping := makeOauthFieldMapping()
adminMapping := &config.OauthAdminMapping{}
rapid.Check(t, func(t *rapid.T) {
sub := rapid.StringMatching(`[a-zA-Z0-9_@.-]{1,50}`).Draw(t, "sub")
email := rapid.String().Draw(t, "email")
firstname := rapid.String().Draw(t, "firstname")
lastname := rapid.String().Draw(t, "lastname")
phone := rapid.String().Draw(t, "phone")
department := rapid.String().Draw(t, "department")
if sub == "" {
sub = "testuser"
}
raw := makeOauthRaw(sub, email, firstname, lastname, phone, department)
// Count how many fields will actually change after sanitization
expectedChanges := 0
if domain.SanitizeIdentifier(sub, 256) != sub {
expectedChanges++
}
if domain.SanitizeEmail(email, 254) != email {
expectedChanges++
}
if domain.SanitizeString(firstname, 128) != firstname {
expectedChanges++
}
if domain.SanitizeString(lastname, 128) != lastname {
expectedChanges++
}
if domain.SanitizePhone(phone, 50) != phone {
expectedChanges++
}
if domain.SanitizeString(department, 128) != department {
expectedChanges++
}
var records []map[string]any
records = captureWarnLogsInline(func() {
_, _ = parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
})
actualWarnCount := testutil.CountWarnEntries(records)
require.Equal(t, expectedChanges, actualWarnCount,
"number of WARN log entries (%d) must equal number of fields changed by sanitization (%d)",
actualWarnCount, expectedChanges)
})
}

View File

@@ -28,7 +28,7 @@ func convertRawLdapUser(
uid := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
return &domain.User{
user := &domain.User{
BaseModel: domain.BaseModel{
CreatedBy: domain.CtxSystemLdapSyncer,
UpdatedBy: domain.CtxSystemLdapSyncer,
@@ -49,10 +49,16 @@ func convertRawLdapUser(
Lastname: internal.MapDefaultString(rawUser, fields.Lastname, ""),
Phone: internal.MapDefaultString(rawUser, fields.Phone, ""),
Department: internal.MapDefaultString(rawUser, fields.Department, ""),
Notes: "",
Password: "",
Disabled: nil,
}, nil
}
if err := user.SanitizeExternalData("ldap", providerName); err != nil {
return nil, err
}
// Update authentication identifier after sanitization
user.Authentications[0].UserIdentifier = user.Identifier
return user, nil
}
func userChangedInLdap(dbUser, ldapUser *domain.User) bool {

View File

@@ -0,0 +1,136 @@
package users
import (
"errors"
"testing"
"github.com/go-ldap/ldap/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/testutil"
)
// makeTestLdapFields returns a minimal LdapFields config for testing.
func makeTestLdapFields() *config.LdapFields {
return &config.LdapFields{
BaseFields: config.BaseFields{
UserIdentifier: "uid",
Email: "mail",
Firstname: "givenName",
Lastname: "sn",
Phone: "telephoneNumber",
Department: "department",
},
GroupMembership: "memberOf",
}
}
// makeTestAdminGroupDN returns a parsed DN for testing (a non-matching group).
func makeTestAdminGroupDN(t *testing.T) *ldap.DN {
t.Helper()
dn, err := ldap.ParseDN("cn=admins,dc=example,dc=com")
require.NoError(t, err)
return dn
}
// makeRawLdapUser builds a raw LDAP user map for convertRawLdapUser.
func makeRawLdapUser(uid, mail, givenName, sn, phone, department string) map[string]any {
return map[string]any{
"uid": uid,
"mail": mail,
"givenName": givenName,
"sn": sn,
"telephoneNumber": phone,
"department": department,
"memberOf": [][]byte{}, // no group memberships
}
}
// Test: identifier "all" → returns ErrInvalidData,
// one WARN log entry with field: "identifier" and cleared indication.
func TestConvertRawLdapUser_IdentifierAll(t *testing.T) {
fields := makeTestLdapFields()
adminGroupDN := makeTestAdminGroupDN(t)
raw := makeRawLdapUser("all", "all@example.com", "Alice", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
records := restore()
require.Error(t, err)
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
assert.Nil(t, user)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
rec, found := testutil.FindWarnWithField(records, "identifier")
assert.True(t, found, "expected WARN log entry with field=identifier")
if found {
msg, _ := rec["msg"].(string)
assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when identifier is cleared")
}
}
// Test: firstname contains \x00 → output firstname has null byte removed,
// one WARN log entry with field: "firstname".
func TestConvertRawLdapUser_NullByteInFirstname(t *testing.T) {
fields := makeTestLdapFields()
adminGroupDN := makeTestAdminGroupDN(t)
raw := makeRawLdapUser("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "")
restore := testutil.CaptureWarnLogs(t)
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
records := restore()
require.NoError(t, err)
require.NotNil(t, user)
assert.NotContains(t, user.Firstname, "\x00", "firstname should have null byte removed")
assert.Equal(t, "Alice", user.Firstname)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
rec, found := testutil.FindWarnWithField(records, "firstname")
assert.True(t, found, "expected WARN log entry with field=firstname")
if found {
assert.Equal(t, "WARN", rec["level"])
}
}
// Test: all fields clean → no WARN log entries emitted.
func TestConvertRawLdapUser_AllFieldsClean(t *testing.T) {
fields := makeTestLdapFields()
adminGroupDN := makeTestAdminGroupDN(t)
raw := makeRawLdapUser("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering")
restore := testutil.CaptureWarnLogs(t)
user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN)
records := restore()
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, domain.UserIdentifier("alice"), user.Identifier)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean")
}
func TestLdapUserIdentifier_NormalizesSyncComparisons(t *testing.T) {
raw := map[string]any{"uid": " alice\x00 "}
got := ldapUserIdentifier(raw, "uid")
assert.Equal(t, domain.UserIdentifier("alice"), got)
}
func TestLdapUserIdentifier_RejectsReservedIdentifier(t *testing.T) {
raw := map[string]any{"uid": " all "}
got := ldapUserIdentifier(raw, "uid")
assert.Empty(t, got)
}

View File

@@ -109,6 +109,11 @@ func (m Manager) updateLdapUsers(
for _, rawUser := range rawUsers {
user, err := convertRawLdapUser(provider.ProviderName, rawUser, fields, adminGroupDN)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
if errors.Is(err, domain.ErrInvalidData) {
slog.Warn("skipping LDAP user with invalid data after sanitization",
"raw-dn", rawUser["dn"], "error", err)
continue
}
return fmt.Errorf("failed to convert LDAP data for %v: %w", rawUser["dn"], err)
}
@@ -212,7 +217,7 @@ func (m Manager) disableMissingLdapUsers(
existsInLDAP := false
for _, rawUser := range rawUsers {
userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
userId := ldapUserIdentifier(rawUser, fields.UserIdentifier)
if user.Identifier == userId {
existsInLDAP = true
break
@@ -258,19 +263,19 @@ func (m Manager) updateInterfaceLdapFilters(
// Combined filter: user must match the provider's base SyncFilter AND the interface's LdapGroupFilter
combinedFilter := fmt.Sprintf("(&(%s)(%s))", provider.SyncFilter, groupFilter)
rawUsers, err := internal.LdapFindAllUsers(conn, provider.BaseDN, combinedFilter, &provider.FieldMap)
if err != nil {
slog.Error("failed to find users for interface filter",
"interface", ifaceId,
"provider", provider.ProviderName,
slog.Error("failed to find users for interface filter",
"interface", ifaceId,
"provider", provider.ProviderName,
"error", err)
continue
}
matchedUserIds := make([]domain.UserIdentifier, 0, len(rawUsers))
for _, rawUser := range rawUsers {
userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, provider.FieldMap.UserIdentifier, ""))
userId := ldapUserIdentifier(rawUser, provider.FieldMap.UserIdentifier)
if userId != "" {
matchedUserIds = append(matchedUserIds, userId)
}
@@ -285,17 +290,26 @@ func (m Manager) updateInterfaceLdapFilters(
return i, nil
})
if err != nil {
slog.Error("failed to save interface ldap allowed users",
"interface", ifaceId,
"provider", provider.ProviderName,
slog.Error("failed to save interface ldap allowed users",
"interface", ifaceId,
"provider", provider.ProviderName,
"error", err)
} else {
slog.Debug("updated interface ldap allowed users",
"interface", ifaceId,
"provider", provider.ProviderName,
slog.Debug("updated interface ldap allowed users",
"interface", ifaceId,
"provider", provider.ProviderName,
"matched_count", len(matchedUserIds))
}
}
return nil
}
func ldapUserIdentifier(rawUser map[string]any, field string) domain.UserIdentifier {
identifier := internal.MapDefaultString(rawUser, field, "")
identifier = domain.SanitizeIdentifier(identifier, 256)
if identifier == "" {
return ""
}
return domain.UserIdentifier(identifier)
}

View File

@@ -365,19 +365,7 @@ func (m Manager) validateCreation(ctx context.Context, new *domain.User) error {
return fmt.Errorf("invalid user identifier: %w", domain.ErrInvalidData)
}
if new.Identifier == "all" { // the 'all' user identifier collides with the rest api routes
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
}
if new.Identifier == "new" { // the 'new' user identifier collides with the rest api routes
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
}
if new.Identifier == "id" { // the 'id' user identifier collides with the rest api routes
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
}
if new.Identifier == domain.CtxSystemAdminId || new.Identifier == domain.CtxUnknownUserId {
if domain.IsReservedUserIdentifier(new.Identifier) {
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
}

View File

@@ -1,5 +1,10 @@
package domain
import (
"fmt"
"strings"
)
type LoginProvider string
type LoginProviderInfo struct {
@@ -20,3 +25,52 @@ type AuthenticatorUserInfo struct {
IsAdmin bool
AdminInfoAvailable bool // true if the IsAdmin flag is valid
}
// Sanitize sanitizes all external identity provider fields in place.
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
func (u *AuthenticatorUserInfo) Sanitize(providerType, providerName string) error {
identifier := string(u.Identifier)
LogSanitizeChange(providerType, providerName, "identifier", identifier,
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
u.Identifier = UserIdentifier(identifier)
email := u.Email
LogSanitizeChange(providerType, providerName, "email", email,
func() string { return SanitizeEmail(email, 254) }, &u.Email)
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
LogSanitizeChange(providerType, providerName, "department", u.Department,
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
u.UserGroups = sanitizeGroups(providerType, providerName, u.UserGroups)
if u.Identifier == "" {
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
}
return nil
}
// sanitizeGroups sanitizes group names, dropping any that were modified by sanitization.
func sanitizeGroups(providerType, providerName string, rawGroups []string) []string {
if len(rawGroups) == 0 {
return rawGroups
}
groups := make([]string, 0, len(rawGroups))
for _, rawGroup := range rawGroups {
sanitized := rawGroup
LogSanitizeChange(providerType, providerName, "user_group", rawGroup,
func() string { return SanitizeString(rawGroup, 256) }, &sanitized)
if sanitized == "" || sanitized != strings.TrimSpace(rawGroup) {
continue
}
groups = append(groups, sanitized)
}
return groups
}

View File

@@ -0,0 +1,103 @@
package domain
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/h44z/wg-portal/internal/testutil"
)
func TestAuthenticatorUserInfo_Sanitize_NullByteInFirstname(t *testing.T) {
info := &AuthenticatorUserInfo{
Identifier: "alice",
Email: "alice@example.com",
Firstname: "Ali\x00ce",
Lastname: "Smith",
}
restore := testutil.CaptureWarnLogs(t)
err := info.Sanitize("ldap", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, "Alice", info.Firstname)
warnCount := testutil.CountWarnEntries(records)
assert.Equal(t, 1, warnCount)
_, found := testutil.FindWarnWithField(records, "firstname")
assert.True(t, found)
}
func TestAuthenticatorUserInfo_Sanitize_AllFieldsClean(t *testing.T) {
info := &AuthenticatorUserInfo{
Identifier: "alice",
Email: "alice@example.com",
Firstname: "Alice",
Lastname: "Smith",
Phone: "+1 555-1234",
Department: "Engineering",
}
restore := testutil.CaptureWarnLogs(t)
err := info.Sanitize("ldap", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, UserIdentifier("alice"), info.Identifier)
assert.Equal(t, 0, testutil.CountWarnEntries(records))
}
func TestAuthenticatorUserInfo_Sanitize_IdentifierAll(t *testing.T) {
info := &AuthenticatorUserInfo{
Identifier: "all",
Email: "all@example.com",
Firstname: "Alice",
Lastname: "Smith",
}
err := info.Sanitize("ldap", "test-provider")
require.Error(t, err)
assert.True(t, errors.Is(err, ErrInvalidData))
}
func TestAuthenticatorUserInfo_Sanitize_CRLFInEmail(t *testing.T) {
info := &AuthenticatorUserInfo{
Identifier: "user123",
Email: "user\r\n@example.com",
Firstname: "Alice",
Lastname: "Smith",
}
restore := testutil.CaptureWarnLogs(t)
err := info.Sanitize("oauth", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, "", info.Email)
_, found := testutil.FindWarnWithField(records, "email")
assert.True(t, found)
}
func TestAuthenticatorUserInfo_Sanitize_GroupsWithZeroWidthChars(t *testing.T) {
info := &AuthenticatorUserInfo{
Identifier: "user123",
Email: "user@example.com",
UserGroups: []string{"wgportal-\u200badmins"},
}
restore := testutil.CaptureWarnLogs(t)
err := info.Sanitize("oidc", "test-provider")
records := restore()
require.NoError(t, err)
assert.Empty(t, info.UserGroups)
_, found := testutil.FindWarnWithField(records, "user_group")
assert.True(t, found)
}

151
internal/domain/sanitize.go Normal file
View File

@@ -0,0 +1,151 @@
package domain
import (
"log/slog"
"net/mail"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/unicode/norm"
)
// LogSanitizeChange applies sanitizeFn to raw, logs when the value changes, and writes
// the sanitized value to dest. Raw and sanitized values are intentionally omitted.
func LogSanitizeChange(
providerType string,
providerName string,
field string,
raw string,
sanitizeFn func() string,
dest *string,
) {
sanitized := sanitizeFn()
if sanitized != raw {
message := "sanitization modified field value from external provider"
if sanitized == "" {
message = "sanitization cleared field value from external provider"
}
slog.Warn(message,
"provider_type", SanitizeString(providerType, 64),
"provider", SanitizeString(providerName, 128),
"field", SanitizeString(field, 64),
)
}
*dest = sanitized
}
var reservedUserIdentifiers = map[string]struct{}{
"all": {},
"new": {},
"id": {},
CtxSystemAdminId: {},
CtxUnknownUserId: {},
CtxSystemLdapSyncer: {},
CtxSystemWgImporter: {},
CtxSystemV1Migrator: {},
CtxSystemDBMigrator: {},
}
// SanitizeString normalizes to NFC, trims leading and trailing whitespace, strips Unicode
// control and format characters, drops invalid UTF-8 bytes, and truncates the result to
// maxLen runes. If maxLen <= 0, returns "".
func SanitizeString(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
s = norm.NFC.String(strings.TrimSpace(s))
var b strings.Builder
b.Grow(len(s))
for len(s) > 0 {
r, size := utf8.DecodeRuneInString(s)
s = s[size:]
if r == utf8.RuneError && size == 1 {
continue
}
if !unicode.IsControl(r) && !unicode.Is(unicode.Cf, r) {
b.WriteRune(r)
}
}
s = b.String()
if utf8.RuneCountInString(s) > maxLen {
runes := []rune(s)
s = string(runes[:maxLen])
}
return strings.TrimSpace(s)
}
// SanitizeEmail applies SanitizeString first, then returns "" if the original s
// contains CR/LF or if the sanitized result is not a plain email address.
func SanitizeEmail(s string, maxLen int) string {
if strings.ContainsRune(s, '\r') || strings.ContainsRune(s, '\n') {
return ""
}
sanitized := SanitizeString(s, maxLen)
if sanitized == "" || strings.Count(sanitized, "@") != 1 {
return ""
}
addr, err := mail.ParseAddress(sanitized)
if err != nil || addr.Name != "" || addr.Address != sanitized {
return ""
}
return sanitized
}
// SanitizePhone applies SanitizeString first, then removes all characters not in the
// set [0-9+\-() .]. Returns "" if the result after filtering is empty.
func SanitizePhone(s string, maxLen int) string {
sanitized := SanitizeString(s, maxLen)
// Remove all characters not in [0-9+\-() .]
var b strings.Builder
b.Grow(len(sanitized))
for _, r := range sanitized {
if isAllowedPhoneRune(r) {
b.WriteRune(r)
}
}
result := strings.TrimSpace(b.String())
if result == "" {
return ""
}
return result
}
// isAllowedPhoneRune reports whether r is in the allowed phone character set [0-9+\-() .].
func isAllowedPhoneRune(r rune) bool {
switch {
case r >= '0' && r <= '9':
return true
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
return true
default:
return false
}
}
// SanitizeIdentifier applies SanitizeString first, then returns "" if the result equals
// a reserved user identifier (case-sensitive, exact match).
func SanitizeIdentifier(s string, maxLen int) string {
sanitized := SanitizeString(s, maxLen)
if IsReservedUserIdentifier(UserIdentifier(sanitized)) {
return ""
}
return sanitized
}
func IsReservedUserIdentifier(identifier UserIdentifier) bool {
_, reserved := reservedUserIdentifiers[string(identifier)]
return reserved
}

View File

@@ -0,0 +1,503 @@
package domain
import (
"net/mail"
"strings"
"testing"
"unicode"
"unicode/utf8"
"pgregory.net/rapid"
)
func TestSanitizeString(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{
name: "null byte removed",
input: "\x00",
maxLen: 64,
want: "",
},
{
name: "CR removed",
input: "\r",
maxLen: 64,
want: "",
},
{
name: "LF removed",
input: "\n",
maxLen: 64,
want: "",
},
{
name: "tab removed",
input: "\t",
maxLen: 64,
want: "",
},
{
name: "leading and trailing whitespace trimmed",
input: " hello ",
maxLen: 64,
want: "hello",
},
{
name: "multi-byte UTF-8 truncation at rune boundary",
input: "héllo",
maxLen: 3,
want: "hél", // 3 runes, not 3 bytes
},
{
name: "empty input",
input: "",
maxLen: 64,
want: "",
},
{
name: "maxLen zero returns empty",
input: "hello",
maxLen: 0,
want: "",
},
{
name: "string longer than maxLen truncated",
input: "abcdefgh",
maxLen: 4,
want: "abcd",
},
{
name: "mixed control chars and normal chars",
input: "hel\x00lo\r\nworld",
maxLen: 64,
want: "helloworld",
},
{
name: "only whitespace returns empty",
input: " ",
maxLen: 64,
want: "",
},
{
name: "string exactly at maxLen not truncated",
input: "abc",
maxLen: 3,
want: "abc",
},
{
name: "negative maxLen returns empty",
input: "hello",
maxLen: -1,
want: "",
},
{
name: "DEL control removed",
input: "hel\x7flo",
maxLen: 64,
want: "hello",
},
{
name: "zero-width format character removed",
input: "ali\u200bce",
maxLen: 64,
want: "alice",
},
{
name: "invalid UTF-8 byte removed",
input: "a\xffb",
maxLen: 64,
want: "ab",
},
{
name: "unicode normalized to NFC",
input: "e\u0301",
maxLen: 64,
want: "é",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SanitizeString(tc.input, tc.maxLen)
if got != tc.want {
t.Errorf("SanitizeString(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
}
})
}
}
func TestSanitizeEmail(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{
name: "valid email passes through unchanged",
input: "user@example.com",
maxLen: 254,
want: "user@example.com",
},
{
name: "CR in email returns empty",
input: "user\r@example.com",
maxLen: 254,
want: "",
},
{
name: "LF in email returns empty",
input: "user\n@example.com",
maxLen: 254,
want: "",
},
{
name: "missing @ returns empty",
input: "userexample.com",
maxLen: 254,
want: "",
},
{
name: "whitespace-only returns empty",
input: " ",
maxLen: 254,
want: "",
},
{
name: "email with leading/trailing whitespace trimmed and returned",
input: " user@example.com ",
maxLen: 254,
want: "user@example.com",
},
{
name: "empty input returns empty",
input: "",
maxLen: 254,
want: "",
},
{
name: "display-name address rejected",
input: "User <user@example.com>",
maxLen: 254,
want: "",
},
{
name: "multiple at signs rejected",
input: "user@@example.com",
maxLen: 254,
want: "",
},
{
name: "invalid address rejected",
input: "user@",
maxLen: 254,
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SanitizeEmail(tc.input, tc.maxLen)
if got != tc.want {
t.Errorf("SanitizeEmail(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
}
})
}
}
func TestSanitizePhone(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{
name: "valid phone passes through unchanged",
input: "+1 (555) 123-4567",
maxLen: 50,
want: "+1 (555) 123-4567",
},
{
name: "non-allowed chars stripped",
input: "abc+1def",
maxLen: 50,
want: "+1",
},
{
name: "all-stripped input returns empty",
input: "abc",
maxLen: 50,
want: "",
},
{
name: "mixed allowed and non-allowed chars",
input: "+49 (0) 123-456.789",
maxLen: 50,
want: "+49 (0) 123-456.789",
},
{
name: "empty input returns empty",
input: "",
maxLen: 50,
want: "",
},
{
name: "only digits passes through",
input: "1234567890",
maxLen: 50,
want: "1234567890",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SanitizePhone(tc.input, tc.maxLen)
if got != tc.want {
t.Errorf("SanitizePhone(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
}
})
}
}
func TestSanitizeIdentifier(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{
name: "reserved value all returns empty",
input: "all",
maxLen: 256,
want: "",
},
{
name: "all with surrounding whitespace returns empty",
input: " all ",
maxLen: 256,
want: "",
},
{
name: "reserved value new returns empty",
input: "new",
maxLen: 256,
want: "",
},
{
name: "reserved value id returns empty",
input: "id",
maxLen: 256,
want: "",
},
{
name: "system admin identifier returns empty",
input: string(CtxSystemAdminId),
maxLen: 256,
want: "",
},
{
name: "unknown user identifier returns empty",
input: string(CtxUnknownUserId),
maxLen: 256,
want: "",
},
{
name: "LDAP syncer identifier returns empty",
input: string(CtxSystemLdapSyncer),
maxLen: 256,
want: "",
},
{
name: "ALL uppercase passes through (case-sensitive)",
input: "ALL",
maxLen: 256,
want: "ALL",
},
{
name: "valid email identifier passes through",
input: "alice@example.com",
maxLen: 256,
want: "alice@example.com",
},
{
name: "normal identifier passes through",
input: "alice",
maxLen: 256,
want: "alice",
},
{
name: "empty input returns empty",
input: "",
maxLen: 256,
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SanitizeIdentifier(tc.input, tc.maxLen)
if got != tc.want {
t.Errorf("SanitizeIdentifier(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
}
})
}
}
func TestSanitizeXSSPayload(t *testing.T) {
// XSS payload: null byte removed, angle brackets preserved
input := "<script>\x00alert(1)</script>"
want := "<script>alert(1)</script>"
got := SanitizeString(input, 256)
if got != want {
t.Errorf("SanitizeString(%q, 256) = %q; want %q", input, got, want)
}
}
// ---------------------------------------------------------------------------
// Property 1: SanitizeString output invariants
// ---------------------------------------------------------------------------
// Feature: external-identity-sanitization, Property 1: SanitizeString output is free of control characters and bounded in length
func TestPropertySanitizeStringOutputInvariants(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
s := rapid.String().Draw(t, "s")
maxLen := rapid.IntRange(0, 512).Draw(t, "maxLen")
result := SanitizeString(s, maxLen)
// No control or format runes in result
for _, r := range result {
if unicode.IsControl(r) || unicode.Is(unicode.Cf, r) {
t.Fatalf("result %q contains unsafe character %U", result, r)
}
}
if !utf8.ValidString(result) {
t.Fatalf("result %q is not valid UTF-8", result)
}
// No leading or trailing whitespace
if result != strings.TrimSpace(result) {
t.Fatalf("result %q has leading or trailing whitespace", result)
}
// Rune count <= maxLen
runeCount := utf8.RuneCountInString(result)
if runeCount > maxLen {
t.Fatalf("result %q has %d runes, exceeds maxLen %d", result, runeCount, maxLen)
}
})
}
// ---------------------------------------------------------------------------
// Property 2: SanitizeString is idempotent
// ---------------------------------------------------------------------------
// Feature: external-identity-sanitization, Property 2: SanitizeString is idempotent
func TestPropertySanitizeStringIdempotent(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
s := rapid.String().Draw(t, "s")
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
once := SanitizeString(s, maxLen)
twice := SanitizeString(once, maxLen)
if once != twice {
t.Fatalf("SanitizeString is not idempotent: once=%q, twice=%q (input=%q, maxLen=%d)",
once, twice, s, maxLen)
}
})
}
// ---------------------------------------------------------------------------
// Property 3: SanitizeEmail rejection rules
// ---------------------------------------------------------------------------
// Feature: external-identity-sanitization, Property 3: SanitizeEmail rejects strings without "@" or containing CR/LF
func TestPropertySanitizeEmailRejectionRules(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
s := rapid.String().Draw(t, "s")
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
result := SanitizeEmail(s, maxLen)
sanitized := SanitizeString(s, maxLen)
addr, parseErr := mail.ParseAddress(sanitized)
reject := strings.ContainsAny(s, "\r\n") ||
sanitized == "" ||
strings.Count(sanitized, "@") != 1 ||
parseErr != nil ||
addr.Name != "" ||
addr.Address != sanitized
if reject {
if result != "" {
t.Fatalf("SanitizeEmail(%q, %d) = %q; expected empty string (contains CR/LF or no @)",
s, maxLen, result)
}
}
})
}
// ---------------------------------------------------------------------------
// Property 4: SanitizePhone allowed character set
// ---------------------------------------------------------------------------
// isAllowedPhoneCharTest mirrors the internal isAllowedPhoneRune logic for test assertions.
func isAllowedPhoneCharTest(r rune) bool {
switch {
case r >= '0' && r <= '9':
return true
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
return true
default:
return false
}
}
// Feature: external-identity-sanitization, Property 4: SanitizePhone output contains only allowed characters
func TestPropertySanitizePhoneAllowedChars(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
s := rapid.String().Draw(t, "s")
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
result := SanitizePhone(s, maxLen)
for _, r := range result {
if !isAllowedPhoneCharTest(r) {
t.Fatalf("SanitizePhone(%q, %d) = %q; contains disallowed rune %U (%c)",
s, maxLen, result, r, r)
}
}
})
}
// ---------------------------------------------------------------------------
// Property 5: SanitizeIdentifier rejects reserved identifiers
// ---------------------------------------------------------------------------
// Feature: external-identity-sanitization, Property 5: SanitizeIdentifier rejects reserved values
func TestPropertySanitizeIdentifierRejectsReservedValues(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
s := rapid.String().Draw(t, "s")
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
result := SanitizeIdentifier(s, maxLen)
sanitized := SanitizeString(s, maxLen)
_, reserved := reservedUserIdentifiers[sanitized]
if reserved {
if result != "" {
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected empty string when sanitized is reserved",
s, maxLen, result)
}
} else {
if result != sanitized {
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected %q (== SanitizeString result)",
s, maxLen, result, sanitized)
}
}
})
}

View File

@@ -282,6 +282,32 @@ func (u *User) CreateDefaultPeers() bool {
return true
}
// SanitizeExternalData sanitizes user profile fields received from an external identity provider.
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
func (u *User) SanitizeExternalData(providerType, providerName string) error {
identifier := string(u.Identifier)
LogSanitizeChange(providerType, providerName, "identifier", identifier,
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
u.Identifier = UserIdentifier(identifier)
LogSanitizeChange(providerType, providerName, "email", u.Email,
func() string { return SanitizeEmail(u.Email, 254) }, &u.Email)
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
LogSanitizeChange(providerType, providerName, "department", u.Department,
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
if u.Identifier == "" {
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
}
return nil
}
// region webauthn
func (u *User) WebAuthnID() []byte {

View File

@@ -0,0 +1,64 @@
package domain
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/h44z/wg-portal/internal/testutil"
)
func TestUser_SanitizeExternalData_NullByteInFirstname(t *testing.T) {
u := &User{
Identifier: "alice",
Email: "alice@example.com",
Firstname: "Ali\x00ce",
Lastname: "Smith",
}
restore := testutil.CaptureWarnLogs(t)
err := u.SanitizeExternalData("ldap", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, "Alice", u.Firstname)
assert.Equal(t, 1, testutil.CountWarnEntries(records))
_, found := testutil.FindWarnWithField(records, "firstname")
assert.True(t, found)
}
func TestUser_SanitizeExternalData_IdentifierAll(t *testing.T) {
u := &User{
Identifier: "all",
Email: "all@example.com",
Firstname: "Alice",
Lastname: "Smith",
}
err := u.SanitizeExternalData("ldap", "test-provider")
require.Error(t, err)
assert.True(t, errors.Is(err, ErrInvalidData))
}
func TestUser_SanitizeExternalData_AllFieldsClean(t *testing.T) {
u := &User{
Identifier: "alice",
Email: "alice@example.com",
Firstname: "Alice",
Lastname: "Smith",
Phone: "+1 555-1234",
Department: "Engineering",
}
restore := testutil.CaptureWarnLogs(t)
err := u.SanitizeExternalData("ldap", "test-provider")
records := restore()
require.NoError(t, err)
assert.Equal(t, UserIdentifier("alice"), u.Identifier)
assert.Equal(t, 0, testutil.CountWarnEntries(records))
}

32
internal/sanitize/log.go Normal file
View File

@@ -0,0 +1,32 @@
package sanitize
import (
"log/slog"
"github.com/h44z/wg-portal/internal/domain"
)
// LogChange applies sanitizeFn to raw, logs when the value changes, and writes
// the sanitized value to dest. Raw and sanitized values are intentionally omitted.
func LogChange(
providerType string,
providerName string,
field string,
raw string,
sanitizeFn func() string,
dest *string,
) {
sanitized := sanitizeFn()
if sanitized != raw {
message := "sanitization modified field value from external provider"
if sanitized == "" {
message = "sanitization cleared field value from external provider"
}
slog.Warn(message,
"provider_type", domain.SanitizeString(providerType, 64),
"provider", domain.SanitizeString(providerName, 128),
"field", domain.SanitizeString(field, 64),
)
}
*dest = sanitized
}

View File

@@ -0,0 +1,50 @@
package testutil
import (
"bytes"
"encoding/json"
"log/slog"
"testing"
)
func CaptureWarnLogs(t *testing.T) (restore func() []map[string]any) {
t.Helper()
original := slog.Default()
var buf bytes.Buffer
handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})
slog.SetDefault(slog.New(handler))
return func() []map[string]any {
slog.SetDefault(original)
var records []map[string]any
decoder := json.NewDecoder(&buf)
for decoder.More() {
var rec map[string]any
if err := decoder.Decode(&rec); err == nil {
records = append(records, rec)
}
}
return records
}
}
func CountWarnEntries(records []map[string]any) int {
count := 0
for _, r := range records {
if lvl, ok := r["level"].(string); ok && lvl == "WARN" {
count++
}
}
return count
}
func FindWarnWithField(records []map[string]any, fieldName string) (map[string]any, bool) {
for _, r := range records {
if lvl, ok := r["level"].(string); ok && lvl == "WARN" {
if f, ok := r["field"].(string); ok && f == fieldName {
return r, true
}
}
}
return nil, false
}