diff --git a/config.yml.sample b/config.yml.sample index 7033456..3be9ce5 100644 --- a/config.yml.sample +++ b/config.yml.sample @@ -113,4 +113,4 @@ backend: api_verify_tls: true api_timeout: 30s concurrency: 5 - debug: false \ No newline at end of file + debug: false diff --git a/docs/documentation/usage/security.md b/docs/documentation/usage/security.md index 08024b2..224a8b4 100644 --- a/docs/documentation/usage/security.md +++ b/docs/documentation/usage/security.md @@ -8,6 +8,23 @@ To enable encryption, set the [`encryption_passphrase`](../configuration/overvie > :warning: Important: Once encryption is enabled, it cannot be disabled, and the passphrase cannot be changed! > Only new or updated records will be encrypted; existing data remains in plaintext until it’s next modified. +## External Identity Provider Data Sanitization + +When users authenticate via LDAP, OIDC, or OAuth, WireGuard Portal sanitizes the field values received from the provider before storing them. This protects against several classes of attack that a compromised or misconfigured identity provider could introduce: + +- **Unsafe control characters** — Unicode control and format characters, null bytes, and invalid UTF-8 bytes are stripped from external profile fields before they reach the Vue.js UI or email templates. +- **Email header injection** — carriage return and line feed characters in email fields are rejected entirely, and email fields must parse as plain email addresses. +- **Log injection** — unsafe control and format characters are stripped from all external profile fields and from sanitization log context. +- **Denial of service via oversized fields** — field lengths are capped (e.g., 256 runes for identifiers, 254 characters for email addresses). +- **Reserved identifier collision** — reserved user identifiers such as `"all"`, `"new"`, `"id"`, and internal system user identifiers are rejected. +- **Unsafe authorization groups** — OIDC/OAuth group claims are sanitized before group-based checks; groups changed by control/format stripping or truncation are dropped rather than repaired into allowed/admin matches. + +Sanitization is always enabled and cannot be disabled. + +When sanitization modifies or clears a field value, a `WARN` log entry is emitted with the provider name, provider type, and field name — but never the raw or sanitized value, to avoid leaking sensitive data into logs. This makes it straightforward to detect and investigate potentially malicious or misconfigured providers. + +--- + ## UI and API Access WireGuard Portal provides a web UI and a REST API for user interaction. It is important to secure these interfaces to prevent unauthorized access and data breaches. @@ -21,4 +38,4 @@ A detailed explanation is available in the [Reverse Proxy](../getting-started/re ### Secure Authentication To prevent unauthorized access, WireGuard Portal supports integrating with secure authentication providers such as LDAP, OAuth2, or Passkeys, see [Authentication](./authentication.md) for more details. When possible, use centralized authentication and enforce multi-factor authentication (MFA) at the provider level for enhanced account security. -For local accounts, administrators should enforce strong password requirements. \ No newline at end of file +For local accounts, administrators should enforce strong password requirements. diff --git a/go.mod b/go.mod index 76ece8d..c88340c 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlserver v1.6.3 gorm.io/gorm v1.31.1 + pgregory.net/rapid v1.2.0 ) require ( diff --git a/go.sum b/go.sum index ec35c76..808d6ee 100644 --- a/go.sum +++ b/go.sum @@ -432,5 +432,7 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk= +pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/internal/app/auth/auth_ldap.go b/internal/app/auth/auth_ldap.go index ddcca83..95a16c9 100644 --- a/internal/app/auth/auth_ldap.go +++ b/internal/app/auth/auth_ldap.go @@ -149,5 +149,9 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica AdminInfoAvailable: adminInfoAvailable, } + if err := userInfo.Sanitize("ldap", l.cfg.ProviderName); err != nil { + return nil, err + } + return userInfo, nil } diff --git a/internal/app/auth/auth_ldap_sanitize_test.go b/internal/app/auth/auth_ldap_sanitize_test.go new file mode 100644 index 0000000..55daad1 --- /dev/null +++ b/internal/app/auth/auth_ldap_sanitize_test.go @@ -0,0 +1,98 @@ +package auth + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/testutil" +) + +// makeLdapAuthenticator creates a minimal LdapAuthenticator for testing ParseUserInfo. +func makeLdapAuthenticator() *LdapAuthenticator { + return &LdapAuthenticator{ + cfg: &config.LdapProvider{ + ProviderName: "test-ldap", + FieldMap: config.LdapFields{ + BaseFields: config.BaseFields{ + UserIdentifier: "uid", + Email: "mail", + Firstname: "givenName", + Lastname: "sn", + Phone: "telephoneNumber", + Department: "department", + }, + GroupMembership: "", // no group membership check + }, + }, + } +} + +// makeRawLdapMap builds a minimal raw LDAP attribute map for ParseUserInfo. +func makeRawLdapMap(uid, mail, givenName, sn, phone, department string) map[string]any { + return map[string]any{ + "uid": uid, + "mail": mail, + "givenName": givenName, + "sn": sn, + "telephoneNumber": phone, + "department": department, + } +} + +// Test: firstname contains \x00 → output firstname has no null byte, +// one WARN log entry with field: "firstname". +func TestLdapParseUserInfo_NullByteInFirstname(t *testing.T) { + auth := makeLdapAuthenticator() + raw := makeRawLdapMap("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + info, err := auth.ParseUserInfo(raw) + records := restore() + + require.NoError(t, err) + assert.NotContains(t, info.Firstname, "\x00", "firstname should have null byte removed") + assert.Equal(t, "Alice", info.Firstname) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry") + + rec, found := testutil.FindWarnWithField(records, "firstname") + assert.True(t, found, "expected WARN log entry with field=firstname") + if found { + assert.Equal(t, "WARN", rec["level"]) + } +} + +// Test: all fields clean → no WARN log entries emitted. +func TestLdapParseUserInfo_AllFieldsClean(t *testing.T) { + auth := makeLdapAuthenticator() + raw := makeRawLdapMap("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering") + + restore := testutil.CaptureWarnLogs(t) + info, err := auth.ParseUserInfo(raw) + records := restore() + + require.NoError(t, err) + assert.Equal(t, domain.UserIdentifier("alice"), info.Identifier) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean") +} + +// Test: identifier is "all" → returns ErrInvalidData. +func TestLdapParseUserInfo_IdentifierAll(t *testing.T) { + auth := makeLdapAuthenticator() + raw := makeRawLdapMap("all", "all@example.com", "Alice", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + _, err := auth.ParseUserInfo(raw) + _ = restore() + + require.Error(t, err) + assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'") +} diff --git a/internal/app/auth/auth_oauth.go b/internal/app/auth/auth_oauth.go index 37f6e39..55eb7e4 100644 --- a/internal/app/auth/auth_oauth.go +++ b/internal/app/auth/auth_oauth.go @@ -155,5 +155,5 @@ func (p PlainOauthAuthenticator) GetUserInfo( // ParseUserInfo parses the user information from the raw data. func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { - return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw) + return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw, "oauth", p.name) } diff --git a/internal/app/auth/auth_oidc.go b/internal/app/auth/auth_oidc.go index ec53fdf..5bcdbc7 100644 --- a/internal/app/auth/auth_oidc.go +++ b/internal/app/auth/auth_oidc.go @@ -194,5 +194,5 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, // ParseUserInfo parses the user info. func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { - return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw) + return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw, "oidc", o.name) } diff --git a/internal/app/auth/oauth_common.go b/internal/app/auth/oauth_common.go index 855823c..586be15 100644 --- a/internal/app/auth/oauth_common.go +++ b/internal/app/auth/oauth_common.go @@ -13,6 +13,8 @@ func parseOauthUserInfo( mapping config.OauthFields, adminMapping *config.OauthAdminMapping, raw map[string]any, + providerType string, + providerName string, ) (*domain.AuthenticatorUserInfo, error) { var isAdmin bool var adminInfoAvailable bool @@ -27,18 +29,6 @@ func parseOauthUserInfo( } } - // next try to parse the user's groups - if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" { - adminInfoAvailable = true - re := adminMapping.GetAdminGroupRegex() - for _, group := range userGroups { - if re.MatchString(strings.TrimSpace(group)) { - isAdmin = true - break - } - } - } - userInfo := &domain.AuthenticatorUserInfo{ Identifier: domain.UserIdentifier(internal.MapDefaultString(raw, mapping.UserIdentifier, "")), Email: internal.MapDefaultString(raw, mapping.Email, ""), @@ -51,6 +41,24 @@ func parseOauthUserInfo( AdminInfoAvailable: adminInfoAvailable, } + if err := userInfo.Sanitize(providerType, providerName); err != nil { + return nil, err + } + + // check admin group match after sanitization + if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" { + adminInfoAvailable = true + re := adminMapping.GetAdminGroupRegex() + for _, group := range userInfo.UserGroups { + if re.MatchString(group) { + isAdmin = true + break + } + } + userInfo.IsAdmin = isAdmin + userInfo.AdminInfoAvailable = adminInfoAvailable + } + return userInfo, nil } diff --git a/internal/app/auth/oauth_common_sanitize_test.go b/internal/app/auth/oauth_common_sanitize_test.go new file mode 100644 index 0000000..bb7a48b --- /dev/null +++ b/internal/app/auth/oauth_common_sanitize_test.go @@ -0,0 +1,148 @@ +package auth + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/testutil" +) + +// makeOauthFieldMapping returns a minimal OauthFields mapping for testing. +func makeOauthFieldMapping() config.OauthFields { + return config.OauthFields{ + BaseFields: config.BaseFields{ + UserIdentifier: "sub", + Email: "email", + Firstname: "given_name", + Lastname: "family_name", + Phone: "phone", + Department: "department", + }, + } +} + +// makeOauthRaw builds a minimal raw OAuth user info map. +func makeOauthRaw(sub, email, givenName, familyName, phone, department string) map[string]any { + return map[string]any{ + "sub": sub, + "email": email, + "given_name": givenName, + "family_name": familyName, + "phone": phone, + "department": department, + } +} + +// Test: email containing \r\n → output email is "", +// one WARN log entry with field: "email" and cleared indication. +func TestParseOauthUserInfo_CRLFInEmail(t *testing.T) { + mapping := makeOauthFieldMapping() + adminMapping := &config.OauthAdminMapping{} + raw := makeOauthRaw("user123", "user\r\n@example.com", "Alice", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, "", info.Email, "email should be cleared when it contains CR/LF") + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry") + + rec, found := testutil.FindWarnWithField(records, "email") + assert.True(t, found, "expected WARN log entry with field=email") + if found { + msg, _ := rec["msg"].(string) + assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when email is cleared") + } +} + +// Test: two fields modified (email cleared, firstname truncated) → +// two separate WARN log entries. +func TestParseOauthUserInfo_TwoFieldsModified(t *testing.T) { + mapping := makeOauthFieldMapping() + adminMapping := &config.OauthAdminMapping{} + + longFirstname := strings.Repeat("A", 200) + raw := makeOauthRaw("user123", "bad\r\nemail@example.com", longFirstname, "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, "", info.Email, "email should be cleared") + assert.Equal(t, 128, len([]rune(info.Firstname)), "firstname should be truncated to 128 runes") + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 2, warnCount, "expected exactly two WARN log entries (one per modified field)") + + _, emailFound := testutil.FindWarnWithField(records, "email") + assert.True(t, emailFound, "expected WARN log entry with field=email") + + _, firstnameFound := testutil.FindWarnWithField(records, "firstname") + assert.True(t, firstnameFound, "expected WARN log entry with field=firstname") +} + +// Test: identifier "all" → returns ErrInvalidData. +func TestParseOauthUserInfo_IdentifierAll(t *testing.T) { + mapping := makeOauthFieldMapping() + adminMapping := &config.OauthAdminMapping{} + raw := makeOauthRaw("all", "all@example.com", "Alice", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + _, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider") + _ = restore() + + require.Error(t, err) + assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'") +} + +func TestParseOauthUserInfo_DropsModifiedGroupBeforeAdminMatch(t *testing.T) { + mapping := makeOauthFieldMapping() + mapping.UserGroups = "groups" + adminMapping := &config.OauthAdminMapping{ + AdminGroupRegex: "^wgportal-admins$", + } + raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "") + raw["groups"] = []any{"wgportal-\u200badmins"} + + restore := testutil.CaptureWarnLogs(t) + info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider") + records := restore() + + require.NoError(t, err) + require.NotNil(t, info) + assert.False(t, info.IsAdmin, "sanitization must not repair a modified group into an admin match") + assert.Empty(t, info.UserGroups) + + rec, found := testutil.FindWarnWithField(records, "user_group") + assert.True(t, found, "expected WARN log entry with field=user_group") + if found { + assert.Equal(t, "oidc", rec["provider_type"]) + } +} + +func TestParseOauthUserInfo_AllowsWhitespaceOnlyGroupTrim(t *testing.T) { + mapping := makeOauthFieldMapping() + mapping.UserGroups = "groups" + adminMapping := &config.OauthAdminMapping{ + AdminGroupRegex: "^wgportal-admins$", + } + raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "") + raw["groups"] = []any{" wgportal-admins "} + + info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider") + + require.NoError(t, err) + require.NotNil(t, info) + assert.True(t, info.IsAdmin) + assert.Equal(t, []string{"wgportal-admins"}, info.UserGroups) +} diff --git a/internal/app/auth/oauth_common_test.go b/internal/app/auth/oauth_common_test.go index 8c08210..d3992e3 100644 --- a/internal/app/auth/oauth_common_test.go +++ b/internal/app/auth/oauth_common_test.go @@ -43,7 +43,7 @@ func Test_parseOauthUserInfo_no_admin(t *testing.T) { }) adminMapping := &config.OauthAdminMapping{} - info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo) + info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider") assert.NoError(t, err) assert.False(t, info.IsAdmin) assert.Equal(t, info.Firstname, "Test User") @@ -90,7 +90,7 @@ func Test_parseOauthUserInfo_admin_group(t *testing.T) { AdminGroupRegex: "^wgportal-admins@mydomain.net$", } - info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo) + info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider") assert.NoError(t, err) assert.True(t, info.IsAdmin) assert.Equal(t, info.Firstname, "Test User") @@ -132,7 +132,7 @@ func Test_parseOauthUserInfo_admin_value(t *testing.T) { }) adminMapping := &config.OauthAdminMapping{} // test with default regex - info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo) + info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider") assert.NoError(t, err) assert.True(t, info.IsAdmin) assert.Equal(t, info.Firstname, "Test User") @@ -175,7 +175,7 @@ func Test_parseOauthUserInfo_admin_value_custom(t *testing.T) { AdminValueRegex: "^1$", } - info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo) + info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider") assert.NoError(t, err) assert.True(t, info.IsAdmin) assert.Equal(t, info.Firstname, "Test User") diff --git a/internal/app/auth/sanitize_log_test.go b/internal/app/auth/sanitize_log_test.go new file mode 100644 index 0000000..d64ba8f --- /dev/null +++ b/internal/app/auth/sanitize_log_test.go @@ -0,0 +1,90 @@ +package auth + +import ( + "bytes" + "encoding/json" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" + "pgregory.net/rapid" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/testutil" +) + +// captureWarnLogsInline redirects the default slog logger to a buffer, calls fn, +// restores the original logger, and returns the captured log records. +func captureWarnLogsInline(fn func()) []map[string]any { + original := slog.Default() + var buf bytes.Buffer + handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn}) + slog.SetDefault(slog.New(handler)) + + fn() + + slog.SetDefault(original) + + var records []map[string]any + decoder := json.NewDecoder(&buf) + for decoder.More() { + var rec map[string]any + if err := decoder.Decode(&rec); err == nil { + records = append(records, rec) + } + } + return records +} + +// Property 7: Sanitization change logging completeness +func TestPropertySanitizationChangeLoggingCompleteness(t *testing.T) { + mapping := makeOauthFieldMapping() + adminMapping := &config.OauthAdminMapping{} + + rapid.Check(t, func(t *rapid.T) { + sub := rapid.StringMatching(`[a-zA-Z0-9_@.-]{1,50}`).Draw(t, "sub") + email := rapid.String().Draw(t, "email") + firstname := rapid.String().Draw(t, "firstname") + lastname := rapid.String().Draw(t, "lastname") + phone := rapid.String().Draw(t, "phone") + department := rapid.String().Draw(t, "department") + + if sub == "" { + sub = "testuser" + } + + raw := makeOauthRaw(sub, email, firstname, lastname, phone, department) + + // Count how many fields will actually change after sanitization + expectedChanges := 0 + if domain.SanitizeIdentifier(sub, 256) != sub { + expectedChanges++ + } + if domain.SanitizeEmail(email, 254) != email { + expectedChanges++ + } + if domain.SanitizeString(firstname, 128) != firstname { + expectedChanges++ + } + if domain.SanitizeString(lastname, 128) != lastname { + expectedChanges++ + } + if domain.SanitizePhone(phone, 50) != phone { + expectedChanges++ + } + if domain.SanitizeString(department, 128) != department { + expectedChanges++ + } + + var records []map[string]any + records = captureWarnLogsInline(func() { + _, _ = parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider") + }) + + actualWarnCount := testutil.CountWarnEntries(records) + require.Equal(t, expectedChanges, actualWarnCount, + "number of WARN log entries (%d) must equal number of fields changed by sanitization (%d)", + actualWarnCount, expectedChanges) + }) +} diff --git a/internal/app/users/ldap_helper.go b/internal/app/users/ldap_helper.go index c202110..51430cb 100644 --- a/internal/app/users/ldap_helper.go +++ b/internal/app/users/ldap_helper.go @@ -28,7 +28,7 @@ func convertRawLdapUser( uid := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, "")) - return &domain.User{ + user := &domain.User{ BaseModel: domain.BaseModel{ CreatedBy: domain.CtxSystemLdapSyncer, UpdatedBy: domain.CtxSystemLdapSyncer, @@ -49,10 +49,16 @@ func convertRawLdapUser( Lastname: internal.MapDefaultString(rawUser, fields.Lastname, ""), Phone: internal.MapDefaultString(rawUser, fields.Phone, ""), Department: internal.MapDefaultString(rawUser, fields.Department, ""), - Notes: "", - Password: "", - Disabled: nil, - }, nil + } + + if err := user.SanitizeExternalData("ldap", providerName); err != nil { + return nil, err + } + + // Update authentication identifier after sanitization + user.Authentications[0].UserIdentifier = user.Identifier + + return user, nil } func userChangedInLdap(dbUser, ldapUser *domain.User) bool { diff --git a/internal/app/users/ldap_helper_sanitize_test.go b/internal/app/users/ldap_helper_sanitize_test.go new file mode 100644 index 0000000..8333992 --- /dev/null +++ b/internal/app/users/ldap_helper_sanitize_test.go @@ -0,0 +1,136 @@ +package users + +import ( + "errors" + "testing" + + "github.com/go-ldap/ldap/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/testutil" +) + +// makeTestLdapFields returns a minimal LdapFields config for testing. +func makeTestLdapFields() *config.LdapFields { + return &config.LdapFields{ + BaseFields: config.BaseFields{ + UserIdentifier: "uid", + Email: "mail", + Firstname: "givenName", + Lastname: "sn", + Phone: "telephoneNumber", + Department: "department", + }, + GroupMembership: "memberOf", + } +} + +// makeTestAdminGroupDN returns a parsed DN for testing (a non-matching group). +func makeTestAdminGroupDN(t *testing.T) *ldap.DN { + t.Helper() + dn, err := ldap.ParseDN("cn=admins,dc=example,dc=com") + require.NoError(t, err) + return dn +} + +// makeRawLdapUser builds a raw LDAP user map for convertRawLdapUser. +func makeRawLdapUser(uid, mail, givenName, sn, phone, department string) map[string]any { + return map[string]any{ + "uid": uid, + "mail": mail, + "givenName": givenName, + "sn": sn, + "telephoneNumber": phone, + "department": department, + "memberOf": [][]byte{}, // no group memberships + } +} + +// Test: identifier "all" → returns ErrInvalidData, +// one WARN log entry with field: "identifier" and cleared indication. +func TestConvertRawLdapUser_IdentifierAll(t *testing.T) { + fields := makeTestLdapFields() + adminGroupDN := makeTestAdminGroupDN(t) + raw := makeRawLdapUser("all", "all@example.com", "Alice", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN) + records := restore() + + require.Error(t, err) + assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'") + assert.Nil(t, user) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry") + + rec, found := testutil.FindWarnWithField(records, "identifier") + assert.True(t, found, "expected WARN log entry with field=identifier") + if found { + msg, _ := rec["msg"].(string) + assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when identifier is cleared") + } +} + +// Test: firstname contains \x00 → output firstname has null byte removed, +// one WARN log entry with field: "firstname". +func TestConvertRawLdapUser_NullByteInFirstname(t *testing.T) { + fields := makeTestLdapFields() + adminGroupDN := makeTestAdminGroupDN(t) + raw := makeRawLdapUser("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "") + + restore := testutil.CaptureWarnLogs(t) + user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN) + records := restore() + + require.NoError(t, err) + require.NotNil(t, user) + assert.NotContains(t, user.Firstname, "\x00", "firstname should have null byte removed") + assert.Equal(t, "Alice", user.Firstname) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry") + + rec, found := testutil.FindWarnWithField(records, "firstname") + assert.True(t, found, "expected WARN log entry with field=firstname") + if found { + assert.Equal(t, "WARN", rec["level"]) + } +} + +// Test: all fields clean → no WARN log entries emitted. +func TestConvertRawLdapUser_AllFieldsClean(t *testing.T) { + fields := makeTestLdapFields() + adminGroupDN := makeTestAdminGroupDN(t) + raw := makeRawLdapUser("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering") + + restore := testutil.CaptureWarnLogs(t) + user, err := convertRawLdapUser("test-ldap", raw, fields, adminGroupDN) + records := restore() + + require.NoError(t, err) + require.NotNil(t, user) + assert.Equal(t, domain.UserIdentifier("alice"), user.Identifier) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean") +} + +func TestLdapUserIdentifier_NormalizesSyncComparisons(t *testing.T) { + raw := map[string]any{"uid": " alice\x00 "} + + got := ldapUserIdentifier(raw, "uid") + + assert.Equal(t, domain.UserIdentifier("alice"), got) +} + +func TestLdapUserIdentifier_RejectsReservedIdentifier(t *testing.T) { + raw := map[string]any{"uid": " all "} + + got := ldapUserIdentifier(raw, "uid") + + assert.Empty(t, got) +} diff --git a/internal/app/users/ldap_sync.go b/internal/app/users/ldap_sync.go index cc0fff9..56cc299 100644 --- a/internal/app/users/ldap_sync.go +++ b/internal/app/users/ldap_sync.go @@ -109,6 +109,11 @@ func (m Manager) updateLdapUsers( for _, rawUser := range rawUsers { user, err := convertRawLdapUser(provider.ProviderName, rawUser, fields, adminGroupDN) if err != nil && !errors.Is(err, domain.ErrNotFound) { + if errors.Is(err, domain.ErrInvalidData) { + slog.Warn("skipping LDAP user with invalid data after sanitization", + "raw-dn", rawUser["dn"], "error", err) + continue + } return fmt.Errorf("failed to convert LDAP data for %v: %w", rawUser["dn"], err) } @@ -212,7 +217,7 @@ func (m Manager) disableMissingLdapUsers( existsInLDAP := false for _, rawUser := range rawUsers { - userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, "")) + userId := ldapUserIdentifier(rawUser, fields.UserIdentifier) if user.Identifier == userId { existsInLDAP = true break @@ -258,19 +263,19 @@ func (m Manager) updateInterfaceLdapFilters( // Combined filter: user must match the provider's base SyncFilter AND the interface's LdapGroupFilter combinedFilter := fmt.Sprintf("(&(%s)(%s))", provider.SyncFilter, groupFilter) - + rawUsers, err := internal.LdapFindAllUsers(conn, provider.BaseDN, combinedFilter, &provider.FieldMap) if err != nil { - slog.Error("failed to find users for interface filter", - "interface", ifaceId, - "provider", provider.ProviderName, + slog.Error("failed to find users for interface filter", + "interface", ifaceId, + "provider", provider.ProviderName, "error", err) continue } matchedUserIds := make([]domain.UserIdentifier, 0, len(rawUsers)) for _, rawUser := range rawUsers { - userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, provider.FieldMap.UserIdentifier, "")) + userId := ldapUserIdentifier(rawUser, provider.FieldMap.UserIdentifier) if userId != "" { matchedUserIds = append(matchedUserIds, userId) } @@ -285,17 +290,26 @@ func (m Manager) updateInterfaceLdapFilters( return i, nil }) if err != nil { - slog.Error("failed to save interface ldap allowed users", - "interface", ifaceId, - "provider", provider.ProviderName, + slog.Error("failed to save interface ldap allowed users", + "interface", ifaceId, + "provider", provider.ProviderName, "error", err) } else { - slog.Debug("updated interface ldap allowed users", - "interface", ifaceId, - "provider", provider.ProviderName, + slog.Debug("updated interface ldap allowed users", + "interface", ifaceId, + "provider", provider.ProviderName, "matched_count", len(matchedUserIds)) } } return nil } + +func ldapUserIdentifier(rawUser map[string]any, field string) domain.UserIdentifier { + identifier := internal.MapDefaultString(rawUser, field, "") + identifier = domain.SanitizeIdentifier(identifier, 256) + if identifier == "" { + return "" + } + return domain.UserIdentifier(identifier) +} diff --git a/internal/app/users/user_manager.go b/internal/app/users/user_manager.go index bb2d3a6..06dbc11 100644 --- a/internal/app/users/user_manager.go +++ b/internal/app/users/user_manager.go @@ -365,19 +365,7 @@ func (m Manager) validateCreation(ctx context.Context, new *domain.User) error { return fmt.Errorf("invalid user identifier: %w", domain.ErrInvalidData) } - if new.Identifier == "all" { // the 'all' user identifier collides with the rest api routes - return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData) - } - - if new.Identifier == "new" { // the 'new' user identifier collides with the rest api routes - return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData) - } - - if new.Identifier == "id" { // the 'id' user identifier collides with the rest api routes - return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData) - } - - if new.Identifier == domain.CtxSystemAdminId || new.Identifier == domain.CtxUnknownUserId { + if domain.IsReservedUserIdentifier(new.Identifier) { return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData) } diff --git a/internal/domain/auth.go b/internal/domain/auth.go index 02853b7..a12e328 100644 --- a/internal/domain/auth.go +++ b/internal/domain/auth.go @@ -1,5 +1,10 @@ package domain +import ( + "fmt" + "strings" +) + type LoginProvider string type LoginProviderInfo struct { @@ -20,3 +25,52 @@ type AuthenticatorUserInfo struct { IsAdmin bool AdminInfoAvailable bool // true if the IsAdmin flag is valid } + +// Sanitize sanitizes all external identity provider fields in place. +// Returns ErrInvalidData if the identifier becomes empty after sanitization. +func (u *AuthenticatorUserInfo) Sanitize(providerType, providerName string) error { + identifier := string(u.Identifier) + LogSanitizeChange(providerType, providerName, "identifier", identifier, + func() string { return SanitizeIdentifier(identifier, 256) }, &identifier) + u.Identifier = UserIdentifier(identifier) + + email := u.Email + LogSanitizeChange(providerType, providerName, "email", email, + func() string { return SanitizeEmail(email, 254) }, &u.Email) + LogSanitizeChange(providerType, providerName, "firstname", u.Firstname, + func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname) + LogSanitizeChange(providerType, providerName, "lastname", u.Lastname, + func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname) + LogSanitizeChange(providerType, providerName, "phone", u.Phone, + func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone) + LogSanitizeChange(providerType, providerName, "department", u.Department, + func() string { return SanitizeString(u.Department, 128) }, &u.Department) + + u.UserGroups = sanitizeGroups(providerType, providerName, u.UserGroups) + + if u.Identifier == "" { + return fmt.Errorf("empty user identifier: %w", ErrInvalidData) + } + + return nil +} + +// sanitizeGroups sanitizes group names, dropping any that were modified by sanitization. +func sanitizeGroups(providerType, providerName string, rawGroups []string) []string { + if len(rawGroups) == 0 { + return rawGroups + } + + groups := make([]string, 0, len(rawGroups)) + for _, rawGroup := range rawGroups { + sanitized := rawGroup + LogSanitizeChange(providerType, providerName, "user_group", rawGroup, + func() string { return SanitizeString(rawGroup, 256) }, &sanitized) + if sanitized == "" || sanitized != strings.TrimSpace(rawGroup) { + continue + } + groups = append(groups, sanitized) + } + + return groups +} diff --git a/internal/domain/auth_sanitize_test.go b/internal/domain/auth_sanitize_test.go new file mode 100644 index 0000000..5098f35 --- /dev/null +++ b/internal/domain/auth_sanitize_test.go @@ -0,0 +1,103 @@ +package domain + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/h44z/wg-portal/internal/testutil" +) + +func TestAuthenticatorUserInfo_Sanitize_NullByteInFirstname(t *testing.T) { + info := &AuthenticatorUserInfo{ + Identifier: "alice", + Email: "alice@example.com", + Firstname: "Ali\x00ce", + Lastname: "Smith", + } + + restore := testutil.CaptureWarnLogs(t) + err := info.Sanitize("ldap", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, "Alice", info.Firstname) + + warnCount := testutil.CountWarnEntries(records) + assert.Equal(t, 1, warnCount) + + _, found := testutil.FindWarnWithField(records, "firstname") + assert.True(t, found) +} + +func TestAuthenticatorUserInfo_Sanitize_AllFieldsClean(t *testing.T) { + info := &AuthenticatorUserInfo{ + Identifier: "alice", + Email: "alice@example.com", + Firstname: "Alice", + Lastname: "Smith", + Phone: "+1 555-1234", + Department: "Engineering", + } + + restore := testutil.CaptureWarnLogs(t) + err := info.Sanitize("ldap", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, UserIdentifier("alice"), info.Identifier) + assert.Equal(t, 0, testutil.CountWarnEntries(records)) +} + +func TestAuthenticatorUserInfo_Sanitize_IdentifierAll(t *testing.T) { + info := &AuthenticatorUserInfo{ + Identifier: "all", + Email: "all@example.com", + Firstname: "Alice", + Lastname: "Smith", + } + + err := info.Sanitize("ldap", "test-provider") + + require.Error(t, err) + assert.True(t, errors.Is(err, ErrInvalidData)) +} + +func TestAuthenticatorUserInfo_Sanitize_CRLFInEmail(t *testing.T) { + info := &AuthenticatorUserInfo{ + Identifier: "user123", + Email: "user\r\n@example.com", + Firstname: "Alice", + Lastname: "Smith", + } + + restore := testutil.CaptureWarnLogs(t) + err := info.Sanitize("oauth", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, "", info.Email) + + _, found := testutil.FindWarnWithField(records, "email") + assert.True(t, found) +} + +func TestAuthenticatorUserInfo_Sanitize_GroupsWithZeroWidthChars(t *testing.T) { + info := &AuthenticatorUserInfo{ + Identifier: "user123", + Email: "user@example.com", + UserGroups: []string{"wgportal-\u200badmins"}, + } + + restore := testutil.CaptureWarnLogs(t) + err := info.Sanitize("oidc", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Empty(t, info.UserGroups) + + _, found := testutil.FindWarnWithField(records, "user_group") + assert.True(t, found) +} diff --git a/internal/domain/sanitize.go b/internal/domain/sanitize.go new file mode 100644 index 0000000..9999a60 --- /dev/null +++ b/internal/domain/sanitize.go @@ -0,0 +1,151 @@ +package domain + +import ( + "log/slog" + "net/mail" + "strings" + "unicode" + "unicode/utf8" + + "golang.org/x/text/unicode/norm" +) + +// LogSanitizeChange applies sanitizeFn to raw, logs when the value changes, and writes +// the sanitized value to dest. Raw and sanitized values are intentionally omitted. +func LogSanitizeChange( + providerType string, + providerName string, + field string, + raw string, + sanitizeFn func() string, + dest *string, +) { + sanitized := sanitizeFn() + if sanitized != raw { + message := "sanitization modified field value from external provider" + if sanitized == "" { + message = "sanitization cleared field value from external provider" + } + slog.Warn(message, + "provider_type", SanitizeString(providerType, 64), + "provider", SanitizeString(providerName, 128), + "field", SanitizeString(field, 64), + ) + } + *dest = sanitized +} + +var reservedUserIdentifiers = map[string]struct{}{ + "all": {}, + "new": {}, + "id": {}, + CtxSystemAdminId: {}, + CtxUnknownUserId: {}, + CtxSystemLdapSyncer: {}, + CtxSystemWgImporter: {}, + CtxSystemV1Migrator: {}, + CtxSystemDBMigrator: {}, +} + +// SanitizeString normalizes to NFC, trims leading and trailing whitespace, strips Unicode +// control and format characters, drops invalid UTF-8 bytes, and truncates the result to +// maxLen runes. If maxLen <= 0, returns "". +func SanitizeString(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + + s = norm.NFC.String(strings.TrimSpace(s)) + + var b strings.Builder + b.Grow(len(s)) + for len(s) > 0 { + r, size := utf8.DecodeRuneInString(s) + s = s[size:] + if r == utf8.RuneError && size == 1 { + continue + } + if !unicode.IsControl(r) && !unicode.Is(unicode.Cf, r) { + b.WriteRune(r) + } + } + s = b.String() + + if utf8.RuneCountInString(s) > maxLen { + runes := []rune(s) + s = string(runes[:maxLen]) + } + + return strings.TrimSpace(s) +} + +// SanitizeEmail applies SanitizeString first, then returns "" if the original s +// contains CR/LF or if the sanitized result is not a plain email address. +func SanitizeEmail(s string, maxLen int) string { + if strings.ContainsRune(s, '\r') || strings.ContainsRune(s, '\n') { + return "" + } + + sanitized := SanitizeString(s, maxLen) + + if sanitized == "" || strings.Count(sanitized, "@") != 1 { + return "" + } + addr, err := mail.ParseAddress(sanitized) + if err != nil || addr.Name != "" || addr.Address != sanitized { + return "" + } + + return sanitized +} + +// SanitizePhone applies SanitizeString first, then removes all characters not in the +// set [0-9+\-() .]. Returns "" if the result after filtering is empty. +func SanitizePhone(s string, maxLen int) string { + sanitized := SanitizeString(s, maxLen) + + // Remove all characters not in [0-9+\-() .] + var b strings.Builder + b.Grow(len(sanitized)) + for _, r := range sanitized { + if isAllowedPhoneRune(r) { + b.WriteRune(r) + } + } + result := strings.TrimSpace(b.String()) + + if result == "" { + return "" + } + + return result +} + +// isAllowedPhoneRune reports whether r is in the allowed phone character set [0-9+\-() .]. +func isAllowedPhoneRune(r rune) bool { + switch { + case r >= '0' && r <= '9': + return true + case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.': + return true + default: + return false + } +} + +// SanitizeIdentifier applies SanitizeString first, then returns "" if the result equals +// a reserved user identifier (case-sensitive, exact match). +func SanitizeIdentifier(s string, maxLen int) string { + sanitized := SanitizeString(s, maxLen) + + if IsReservedUserIdentifier(UserIdentifier(sanitized)) { + return "" + } + + return sanitized +} + +func IsReservedUserIdentifier(identifier UserIdentifier) bool { + _, reserved := reservedUserIdentifiers[string(identifier)] + return reserved +} diff --git a/internal/domain/sanitize_test.go b/internal/domain/sanitize_test.go new file mode 100644 index 0000000..d3e5b6a --- /dev/null +++ b/internal/domain/sanitize_test.go @@ -0,0 +1,503 @@ +package domain + +import ( + "net/mail" + "strings" + "testing" + "unicode" + "unicode/utf8" + + "pgregory.net/rapid" +) + +func TestSanitizeString(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "null byte removed", + input: "\x00", + maxLen: 64, + want: "", + }, + { + name: "CR removed", + input: "\r", + maxLen: 64, + want: "", + }, + { + name: "LF removed", + input: "\n", + maxLen: 64, + want: "", + }, + { + name: "tab removed", + input: "\t", + maxLen: 64, + want: "", + }, + { + name: "leading and trailing whitespace trimmed", + input: " hello ", + maxLen: 64, + want: "hello", + }, + { + name: "multi-byte UTF-8 truncation at rune boundary", + input: "héllo", + maxLen: 3, + want: "hél", // 3 runes, not 3 bytes + }, + { + name: "empty input", + input: "", + maxLen: 64, + want: "", + }, + { + name: "maxLen zero returns empty", + input: "hello", + maxLen: 0, + want: "", + }, + { + name: "string longer than maxLen truncated", + input: "abcdefgh", + maxLen: 4, + want: "abcd", + }, + { + name: "mixed control chars and normal chars", + input: "hel\x00lo\r\nworld", + maxLen: 64, + want: "helloworld", + }, + { + name: "only whitespace returns empty", + input: " ", + maxLen: 64, + want: "", + }, + { + name: "string exactly at maxLen not truncated", + input: "abc", + maxLen: 3, + want: "abc", + }, + { + name: "negative maxLen returns empty", + input: "hello", + maxLen: -1, + want: "", + }, + { + name: "DEL control removed", + input: "hel\x7flo", + maxLen: 64, + want: "hello", + }, + { + name: "zero-width format character removed", + input: "ali\u200bce", + maxLen: 64, + want: "alice", + }, + { + name: "invalid UTF-8 byte removed", + input: "a\xffb", + maxLen: 64, + want: "ab", + }, + { + name: "unicode normalized to NFC", + input: "e\u0301", + maxLen: 64, + want: "é", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SanitizeString(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("SanitizeString(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want) + } + }) + } +} + +func TestSanitizeEmail(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "valid email passes through unchanged", + input: "user@example.com", + maxLen: 254, + want: "user@example.com", + }, + { + name: "CR in email returns empty", + input: "user\r@example.com", + maxLen: 254, + want: "", + }, + { + name: "LF in email returns empty", + input: "user\n@example.com", + maxLen: 254, + want: "", + }, + { + name: "missing @ returns empty", + input: "userexample.com", + maxLen: 254, + want: "", + }, + { + name: "whitespace-only returns empty", + input: " ", + maxLen: 254, + want: "", + }, + { + name: "email with leading/trailing whitespace trimmed and returned", + input: " user@example.com ", + maxLen: 254, + want: "user@example.com", + }, + { + name: "empty input returns empty", + input: "", + maxLen: 254, + want: "", + }, + { + name: "display-name address rejected", + input: "User ", + maxLen: 254, + want: "", + }, + { + name: "multiple at signs rejected", + input: "user@@example.com", + maxLen: 254, + want: "", + }, + { + name: "invalid address rejected", + input: "user@", + maxLen: 254, + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SanitizeEmail(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("SanitizeEmail(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want) + } + }) + } +} + +func TestSanitizePhone(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "valid phone passes through unchanged", + input: "+1 (555) 123-4567", + maxLen: 50, + want: "+1 (555) 123-4567", + }, + { + name: "non-allowed chars stripped", + input: "abc+1def", + maxLen: 50, + want: "+1", + }, + { + name: "all-stripped input returns empty", + input: "abc", + maxLen: 50, + want: "", + }, + { + name: "mixed allowed and non-allowed chars", + input: "+49 (0) 123-456.789", + maxLen: 50, + want: "+49 (0) 123-456.789", + }, + { + name: "empty input returns empty", + input: "", + maxLen: 50, + want: "", + }, + { + name: "only digits passes through", + input: "1234567890", + maxLen: 50, + want: "1234567890", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SanitizePhone(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("SanitizePhone(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want) + } + }) + } +} + +func TestSanitizeIdentifier(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "reserved value all returns empty", + input: "all", + maxLen: 256, + want: "", + }, + { + name: "all with surrounding whitespace returns empty", + input: " all ", + maxLen: 256, + want: "", + }, + { + name: "reserved value new returns empty", + input: "new", + maxLen: 256, + want: "", + }, + { + name: "reserved value id returns empty", + input: "id", + maxLen: 256, + want: "", + }, + { + name: "system admin identifier returns empty", + input: string(CtxSystemAdminId), + maxLen: 256, + want: "", + }, + { + name: "unknown user identifier returns empty", + input: string(CtxUnknownUserId), + maxLen: 256, + want: "", + }, + { + name: "LDAP syncer identifier returns empty", + input: string(CtxSystemLdapSyncer), + maxLen: 256, + want: "", + }, + { + name: "ALL uppercase passes through (case-sensitive)", + input: "ALL", + maxLen: 256, + want: "ALL", + }, + { + name: "valid email identifier passes through", + input: "alice@example.com", + maxLen: 256, + want: "alice@example.com", + }, + { + name: "normal identifier passes through", + input: "alice", + maxLen: 256, + want: "alice", + }, + { + name: "empty input returns empty", + input: "", + maxLen: 256, + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SanitizeIdentifier(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("SanitizeIdentifier(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want) + } + }) + } +} + +func TestSanitizeXSSPayload(t *testing.T) { + // XSS payload: null byte removed, angle brackets preserved + input := "" + want := "" + got := SanitizeString(input, 256) + if got != want { + t.Errorf("SanitizeString(%q, 256) = %q; want %q", input, got, want) + } +} + +// --------------------------------------------------------------------------- +// Property 1: SanitizeString output invariants +// --------------------------------------------------------------------------- + +// Feature: external-identity-sanitization, Property 1: SanitizeString output is free of control characters and bounded in length +func TestPropertySanitizeStringOutputInvariants(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + s := rapid.String().Draw(t, "s") + maxLen := rapid.IntRange(0, 512).Draw(t, "maxLen") + result := SanitizeString(s, maxLen) + + // No control or format runes in result + for _, r := range result { + if unicode.IsControl(r) || unicode.Is(unicode.Cf, r) { + t.Fatalf("result %q contains unsafe character %U", result, r) + } + } + + if !utf8.ValidString(result) { + t.Fatalf("result %q is not valid UTF-8", result) + } + + // No leading or trailing whitespace + if result != strings.TrimSpace(result) { + t.Fatalf("result %q has leading or trailing whitespace", result) + } + + // Rune count <= maxLen + runeCount := utf8.RuneCountInString(result) + if runeCount > maxLen { + t.Fatalf("result %q has %d runes, exceeds maxLen %d", result, runeCount, maxLen) + } + }) +} + +// --------------------------------------------------------------------------- +// Property 2: SanitizeString is idempotent +// --------------------------------------------------------------------------- + +// Feature: external-identity-sanitization, Property 2: SanitizeString is idempotent +func TestPropertySanitizeStringIdempotent(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + s := rapid.String().Draw(t, "s") + maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen") + + once := SanitizeString(s, maxLen) + twice := SanitizeString(once, maxLen) + + if once != twice { + t.Fatalf("SanitizeString is not idempotent: once=%q, twice=%q (input=%q, maxLen=%d)", + once, twice, s, maxLen) + } + }) +} + +// --------------------------------------------------------------------------- +// Property 3: SanitizeEmail rejection rules +// --------------------------------------------------------------------------- + +// Feature: external-identity-sanitization, Property 3: SanitizeEmail rejects strings without "@" or containing CR/LF +func TestPropertySanitizeEmailRejectionRules(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + s := rapid.String().Draw(t, "s") + maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen") + result := SanitizeEmail(s, maxLen) + + sanitized := SanitizeString(s, maxLen) + addr, parseErr := mail.ParseAddress(sanitized) + reject := strings.ContainsAny(s, "\r\n") || + sanitized == "" || + strings.Count(sanitized, "@") != 1 || + parseErr != nil || + addr.Name != "" || + addr.Address != sanitized + if reject { + if result != "" { + t.Fatalf("SanitizeEmail(%q, %d) = %q; expected empty string (contains CR/LF or no @)", + s, maxLen, result) + } + } + }) +} + +// --------------------------------------------------------------------------- +// Property 4: SanitizePhone allowed character set +// --------------------------------------------------------------------------- + +// isAllowedPhoneCharTest mirrors the internal isAllowedPhoneRune logic for test assertions. +func isAllowedPhoneCharTest(r rune) bool { + switch { + case r >= '0' && r <= '9': + return true + case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.': + return true + default: + return false + } +} + +// Feature: external-identity-sanitization, Property 4: SanitizePhone output contains only allowed characters +func TestPropertySanitizePhoneAllowedChars(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + s := rapid.String().Draw(t, "s") + maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen") + result := SanitizePhone(s, maxLen) + + for _, r := range result { + if !isAllowedPhoneCharTest(r) { + t.Fatalf("SanitizePhone(%q, %d) = %q; contains disallowed rune %U (%c)", + s, maxLen, result, r, r) + } + } + }) +} + +// --------------------------------------------------------------------------- +// Property 5: SanitizeIdentifier rejects reserved identifiers +// --------------------------------------------------------------------------- + +// Feature: external-identity-sanitization, Property 5: SanitizeIdentifier rejects reserved values +func TestPropertySanitizeIdentifierRejectsReservedValues(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + s := rapid.String().Draw(t, "s") + maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen") + result := SanitizeIdentifier(s, maxLen) + sanitized := SanitizeString(s, maxLen) + + _, reserved := reservedUserIdentifiers[sanitized] + if reserved { + if result != "" { + t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected empty string when sanitized is reserved", + s, maxLen, result) + } + } else { + if result != sanitized { + t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected %q (== SanitizeString result)", + s, maxLen, result, sanitized) + } + } + }) +} diff --git a/internal/domain/user.go b/internal/domain/user.go index 39fc982..4eed5f2 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -282,6 +282,32 @@ func (u *User) CreateDefaultPeers() bool { return true } +// SanitizeExternalData sanitizes user profile fields received from an external identity provider. +// Returns ErrInvalidData if the identifier becomes empty after sanitization. +func (u *User) SanitizeExternalData(providerType, providerName string) error { + identifier := string(u.Identifier) + LogSanitizeChange(providerType, providerName, "identifier", identifier, + func() string { return SanitizeIdentifier(identifier, 256) }, &identifier) + u.Identifier = UserIdentifier(identifier) + + LogSanitizeChange(providerType, providerName, "email", u.Email, + func() string { return SanitizeEmail(u.Email, 254) }, &u.Email) + LogSanitizeChange(providerType, providerName, "firstname", u.Firstname, + func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname) + LogSanitizeChange(providerType, providerName, "lastname", u.Lastname, + func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname) + LogSanitizeChange(providerType, providerName, "phone", u.Phone, + func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone) + LogSanitizeChange(providerType, providerName, "department", u.Department, + func() string { return SanitizeString(u.Department, 128) }, &u.Department) + + if u.Identifier == "" { + return fmt.Errorf("empty user identifier: %w", ErrInvalidData) + } + + return nil +} + // region webauthn func (u *User) WebAuthnID() []byte { diff --git a/internal/domain/user_sanitize_test.go b/internal/domain/user_sanitize_test.go new file mode 100644 index 0000000..0d5f8a5 --- /dev/null +++ b/internal/domain/user_sanitize_test.go @@ -0,0 +1,64 @@ +package domain + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/h44z/wg-portal/internal/testutil" +) + +func TestUser_SanitizeExternalData_NullByteInFirstname(t *testing.T) { + u := &User{ + Identifier: "alice", + Email: "alice@example.com", + Firstname: "Ali\x00ce", + Lastname: "Smith", + } + + restore := testutil.CaptureWarnLogs(t) + err := u.SanitizeExternalData("ldap", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, "Alice", u.Firstname) + assert.Equal(t, 1, testutil.CountWarnEntries(records)) + + _, found := testutil.FindWarnWithField(records, "firstname") + assert.True(t, found) +} + +func TestUser_SanitizeExternalData_IdentifierAll(t *testing.T) { + u := &User{ + Identifier: "all", + Email: "all@example.com", + Firstname: "Alice", + Lastname: "Smith", + } + + err := u.SanitizeExternalData("ldap", "test-provider") + + require.Error(t, err) + assert.True(t, errors.Is(err, ErrInvalidData)) +} + +func TestUser_SanitizeExternalData_AllFieldsClean(t *testing.T) { + u := &User{ + Identifier: "alice", + Email: "alice@example.com", + Firstname: "Alice", + Lastname: "Smith", + Phone: "+1 555-1234", + Department: "Engineering", + } + + restore := testutil.CaptureWarnLogs(t) + err := u.SanitizeExternalData("ldap", "test-provider") + records := restore() + + require.NoError(t, err) + assert.Equal(t, UserIdentifier("alice"), u.Identifier) + assert.Equal(t, 0, testutil.CountWarnEntries(records)) +} diff --git a/internal/sanitize/log.go b/internal/sanitize/log.go new file mode 100644 index 0000000..51e4d9b --- /dev/null +++ b/internal/sanitize/log.go @@ -0,0 +1,32 @@ +package sanitize + +import ( + "log/slog" + + "github.com/h44z/wg-portal/internal/domain" +) + +// LogChange applies sanitizeFn to raw, logs when the value changes, and writes +// the sanitized value to dest. Raw and sanitized values are intentionally omitted. +func LogChange( + providerType string, + providerName string, + field string, + raw string, + sanitizeFn func() string, + dest *string, +) { + sanitized := sanitizeFn() + if sanitized != raw { + message := "sanitization modified field value from external provider" + if sanitized == "" { + message = "sanitization cleared field value from external provider" + } + slog.Warn(message, + "provider_type", domain.SanitizeString(providerType, 64), + "provider", domain.SanitizeString(providerName, 128), + "field", domain.SanitizeString(field, 64), + ) + } + *dest = sanitized +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..c34c37a --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,50 @@ +package testutil + +import ( + "bytes" + "encoding/json" + "log/slog" + "testing" +) + +func CaptureWarnLogs(t *testing.T) (restore func() []map[string]any) { + t.Helper() + original := slog.Default() + var buf bytes.Buffer + handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn}) + slog.SetDefault(slog.New(handler)) + + return func() []map[string]any { + slog.SetDefault(original) + var records []map[string]any + decoder := json.NewDecoder(&buf) + for decoder.More() { + var rec map[string]any + if err := decoder.Decode(&rec); err == nil { + records = append(records, rec) + } + } + return records + } +} + +func CountWarnEntries(records []map[string]any) int { + count := 0 + for _, r := range records { + if lvl, ok := r["level"].(string); ok && lvl == "WARN" { + count++ + } + } + return count +} + +func FindWarnWithField(records []map[string]any, fieldName string) (map[string]any, bool) { + for _, r := range records { + if lvl, ok := r["level"].(string); ok && lvl == "WARN" { + if f, ok := r["field"].(string); ok && f == fieldName { + return r, true + } + } + } + return nil, false +}