mirror of
https://github.com/h44z/wg-portal.git
synced 2026-05-28 08:56:17 +00:00
feat: sanitize external identity provider user data (#681)
* feat: sanitize external user data * remove config option to disable Sanitization: sanitize_external_user_data * cleanup --------- Co-authored-by: Christoph Haas <christoph.h@sprinternet.at>
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type LoginProvider string
|
||||
|
||||
type LoginProviderInfo struct {
|
||||
@@ -20,3 +25,52 @@ type AuthenticatorUserInfo struct {
|
||||
IsAdmin bool
|
||||
AdminInfoAvailable bool // true if the IsAdmin flag is valid
|
||||
}
|
||||
|
||||
// Sanitize sanitizes all external identity provider fields in place.
|
||||
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
|
||||
func (u *AuthenticatorUserInfo) Sanitize(providerType, providerName string) error {
|
||||
identifier := string(u.Identifier)
|
||||
LogSanitizeChange(providerType, providerName, "identifier", identifier,
|
||||
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
|
||||
u.Identifier = UserIdentifier(identifier)
|
||||
|
||||
email := u.Email
|
||||
LogSanitizeChange(providerType, providerName, "email", email,
|
||||
func() string { return SanitizeEmail(email, 254) }, &u.Email)
|
||||
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
|
||||
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
|
||||
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
|
||||
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
|
||||
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
|
||||
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
|
||||
LogSanitizeChange(providerType, providerName, "department", u.Department,
|
||||
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
|
||||
|
||||
u.UserGroups = sanitizeGroups(providerType, providerName, u.UserGroups)
|
||||
|
||||
if u.Identifier == "" {
|
||||
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeGroups sanitizes group names, dropping any that were modified by sanitization.
|
||||
func sanitizeGroups(providerType, providerName string, rawGroups []string) []string {
|
||||
if len(rawGroups) == 0 {
|
||||
return rawGroups
|
||||
}
|
||||
|
||||
groups := make([]string, 0, len(rawGroups))
|
||||
for _, rawGroup := range rawGroups {
|
||||
sanitized := rawGroup
|
||||
LogSanitizeChange(providerType, providerName, "user_group", rawGroup,
|
||||
func() string { return SanitizeString(rawGroup, 256) }, &sanitized)
|
||||
if sanitized == "" || sanitized != strings.TrimSpace(rawGroup) {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, sanitized)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
103
internal/domain/auth_sanitize_test.go
Normal file
103
internal/domain/auth_sanitize_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/testutil"
|
||||
)
|
||||
|
||||
func TestAuthenticatorUserInfo_Sanitize_NullByteInFirstname(t *testing.T) {
|
||||
info := &AuthenticatorUserInfo{
|
||||
Identifier: "alice",
|
||||
Email: "alice@example.com",
|
||||
Firstname: "Ali\x00ce",
|
||||
Lastname: "Smith",
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := info.Sanitize("ldap", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Alice", info.Firstname)
|
||||
|
||||
warnCount := testutil.CountWarnEntries(records)
|
||||
assert.Equal(t, 1, warnCount)
|
||||
|
||||
_, found := testutil.FindWarnWithField(records, "firstname")
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestAuthenticatorUserInfo_Sanitize_AllFieldsClean(t *testing.T) {
|
||||
info := &AuthenticatorUserInfo{
|
||||
Identifier: "alice",
|
||||
Email: "alice@example.com",
|
||||
Firstname: "Alice",
|
||||
Lastname: "Smith",
|
||||
Phone: "+1 555-1234",
|
||||
Department: "Engineering",
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := info.Sanitize("ldap", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, UserIdentifier("alice"), info.Identifier)
|
||||
assert.Equal(t, 0, testutil.CountWarnEntries(records))
|
||||
}
|
||||
|
||||
func TestAuthenticatorUserInfo_Sanitize_IdentifierAll(t *testing.T) {
|
||||
info := &AuthenticatorUserInfo{
|
||||
Identifier: "all",
|
||||
Email: "all@example.com",
|
||||
Firstname: "Alice",
|
||||
Lastname: "Smith",
|
||||
}
|
||||
|
||||
err := info.Sanitize("ldap", "test-provider")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrInvalidData))
|
||||
}
|
||||
|
||||
func TestAuthenticatorUserInfo_Sanitize_CRLFInEmail(t *testing.T) {
|
||||
info := &AuthenticatorUserInfo{
|
||||
Identifier: "user123",
|
||||
Email: "user\r\n@example.com",
|
||||
Firstname: "Alice",
|
||||
Lastname: "Smith",
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := info.Sanitize("oauth", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", info.Email)
|
||||
|
||||
_, found := testutil.FindWarnWithField(records, "email")
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestAuthenticatorUserInfo_Sanitize_GroupsWithZeroWidthChars(t *testing.T) {
|
||||
info := &AuthenticatorUserInfo{
|
||||
Identifier: "user123",
|
||||
Email: "user@example.com",
|
||||
UserGroups: []string{"wgportal-\u200badmins"},
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := info.Sanitize("oidc", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, info.UserGroups)
|
||||
|
||||
_, found := testutil.FindWarnWithField(records, "user_group")
|
||||
assert.True(t, found)
|
||||
}
|
||||
151
internal/domain/sanitize.go
Normal file
151
internal/domain/sanitize.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
// LogSanitizeChange applies sanitizeFn to raw, logs when the value changes, and writes
|
||||
// the sanitized value to dest. Raw and sanitized values are intentionally omitted.
|
||||
func LogSanitizeChange(
|
||||
providerType string,
|
||||
providerName string,
|
||||
field string,
|
||||
raw string,
|
||||
sanitizeFn func() string,
|
||||
dest *string,
|
||||
) {
|
||||
sanitized := sanitizeFn()
|
||||
if sanitized != raw {
|
||||
message := "sanitization modified field value from external provider"
|
||||
if sanitized == "" {
|
||||
message = "sanitization cleared field value from external provider"
|
||||
}
|
||||
slog.Warn(message,
|
||||
"provider_type", SanitizeString(providerType, 64),
|
||||
"provider", SanitizeString(providerName, 128),
|
||||
"field", SanitizeString(field, 64),
|
||||
)
|
||||
}
|
||||
*dest = sanitized
|
||||
}
|
||||
|
||||
var reservedUserIdentifiers = map[string]struct{}{
|
||||
"all": {},
|
||||
"new": {},
|
||||
"id": {},
|
||||
CtxSystemAdminId: {},
|
||||
CtxUnknownUserId: {},
|
||||
CtxSystemLdapSyncer: {},
|
||||
CtxSystemWgImporter: {},
|
||||
CtxSystemV1Migrator: {},
|
||||
CtxSystemDBMigrator: {},
|
||||
}
|
||||
|
||||
// SanitizeString normalizes to NFC, trims leading and trailing whitespace, strips Unicode
|
||||
// control and format characters, drops invalid UTF-8 bytes, and truncates the result to
|
||||
// maxLen runes. If maxLen <= 0, returns "".
|
||||
func SanitizeString(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
s = norm.NFC.String(strings.TrimSpace(s))
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
for len(s) > 0 {
|
||||
r, size := utf8.DecodeRuneInString(s)
|
||||
s = s[size:]
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
continue
|
||||
}
|
||||
if !unicode.IsControl(r) && !unicode.Is(unicode.Cf, r) {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
s = b.String()
|
||||
|
||||
if utf8.RuneCountInString(s) > maxLen {
|
||||
runes := []rune(s)
|
||||
s = string(runes[:maxLen])
|
||||
}
|
||||
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
// SanitizeEmail applies SanitizeString first, then returns "" if the original s
|
||||
// contains CR/LF or if the sanitized result is not a plain email address.
|
||||
func SanitizeEmail(s string, maxLen int) string {
|
||||
if strings.ContainsRune(s, '\r') || strings.ContainsRune(s, '\n') {
|
||||
return ""
|
||||
}
|
||||
|
||||
sanitized := SanitizeString(s, maxLen)
|
||||
|
||||
if sanitized == "" || strings.Count(sanitized, "@") != 1 {
|
||||
return ""
|
||||
}
|
||||
addr, err := mail.ParseAddress(sanitized)
|
||||
if err != nil || addr.Name != "" || addr.Address != sanitized {
|
||||
return ""
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// SanitizePhone applies SanitizeString first, then removes all characters not in the
|
||||
// set [0-9+\-() .]. Returns "" if the result after filtering is empty.
|
||||
func SanitizePhone(s string, maxLen int) string {
|
||||
sanitized := SanitizeString(s, maxLen)
|
||||
|
||||
// Remove all characters not in [0-9+\-() .]
|
||||
var b strings.Builder
|
||||
b.Grow(len(sanitized))
|
||||
for _, r := range sanitized {
|
||||
if isAllowedPhoneRune(r) {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
result := strings.TrimSpace(b.String())
|
||||
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// isAllowedPhoneRune reports whether r is in the allowed phone character set [0-9+\-() .].
|
||||
func isAllowedPhoneRune(r rune) bool {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
return true
|
||||
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeIdentifier applies SanitizeString first, then returns "" if the result equals
|
||||
// a reserved user identifier (case-sensitive, exact match).
|
||||
func SanitizeIdentifier(s string, maxLen int) string {
|
||||
sanitized := SanitizeString(s, maxLen)
|
||||
|
||||
if IsReservedUserIdentifier(UserIdentifier(sanitized)) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func IsReservedUserIdentifier(identifier UserIdentifier) bool {
|
||||
_, reserved := reservedUserIdentifiers[string(identifier)]
|
||||
return reserved
|
||||
}
|
||||
503
internal/domain/sanitize_test.go
Normal file
503
internal/domain/sanitize_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"pgregory.net/rapid"
|
||||
)
|
||||
|
||||
func TestSanitizeString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "null byte removed",
|
||||
input: "\x00",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "CR removed",
|
||||
input: "\r",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "LF removed",
|
||||
input: "\n",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tab removed",
|
||||
input: "\t",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing whitespace trimmed",
|
||||
input: " hello ",
|
||||
maxLen: 64,
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "multi-byte UTF-8 truncation at rune boundary",
|
||||
input: "héllo",
|
||||
maxLen: 3,
|
||||
want: "hél", // 3 runes, not 3 bytes
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "maxLen zero returns empty",
|
||||
input: "hello",
|
||||
maxLen: 0,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "string longer than maxLen truncated",
|
||||
input: "abcdefgh",
|
||||
maxLen: 4,
|
||||
want: "abcd",
|
||||
},
|
||||
{
|
||||
name: "mixed control chars and normal chars",
|
||||
input: "hel\x00lo\r\nworld",
|
||||
maxLen: 64,
|
||||
want: "helloworld",
|
||||
},
|
||||
{
|
||||
name: "only whitespace returns empty",
|
||||
input: " ",
|
||||
maxLen: 64,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "string exactly at maxLen not truncated",
|
||||
input: "abc",
|
||||
maxLen: 3,
|
||||
want: "abc",
|
||||
},
|
||||
{
|
||||
name: "negative maxLen returns empty",
|
||||
input: "hello",
|
||||
maxLen: -1,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "DEL control removed",
|
||||
input: "hel\x7flo",
|
||||
maxLen: 64,
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "zero-width format character removed",
|
||||
input: "ali\u200bce",
|
||||
maxLen: 64,
|
||||
want: "alice",
|
||||
},
|
||||
{
|
||||
name: "invalid UTF-8 byte removed",
|
||||
input: "a\xffb",
|
||||
maxLen: 64,
|
||||
want: "ab",
|
||||
},
|
||||
{
|
||||
name: "unicode normalized to NFC",
|
||||
input: "e\u0301",
|
||||
maxLen: 64,
|
||||
want: "é",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SanitizeString(tc.input, tc.maxLen)
|
||||
if got != tc.want {
|
||||
t.Errorf("SanitizeString(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid email passes through unchanged",
|
||||
input: "user@example.com",
|
||||
maxLen: 254,
|
||||
want: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "CR in email returns empty",
|
||||
input: "user\r@example.com",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "LF in email returns empty",
|
||||
input: "user\n@example.com",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing @ returns empty",
|
||||
input: "userexample.com",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace-only returns empty",
|
||||
input: " ",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "email with leading/trailing whitespace trimmed and returned",
|
||||
input: " user@example.com ",
|
||||
maxLen: 254,
|
||||
want: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty input returns empty",
|
||||
input: "",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "display-name address rejected",
|
||||
input: "User <user@example.com>",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "multiple at signs rejected",
|
||||
input: "user@@example.com",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid address rejected",
|
||||
input: "user@",
|
||||
maxLen: 254,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SanitizeEmail(tc.input, tc.maxLen)
|
||||
if got != tc.want {
|
||||
t.Errorf("SanitizeEmail(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizePhone(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid phone passes through unchanged",
|
||||
input: "+1 (555) 123-4567",
|
||||
maxLen: 50,
|
||||
want: "+1 (555) 123-4567",
|
||||
},
|
||||
{
|
||||
name: "non-allowed chars stripped",
|
||||
input: "abc+1def",
|
||||
maxLen: 50,
|
||||
want: "+1",
|
||||
},
|
||||
{
|
||||
name: "all-stripped input returns empty",
|
||||
input: "abc",
|
||||
maxLen: 50,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "mixed allowed and non-allowed chars",
|
||||
input: "+49 (0) 123-456.789",
|
||||
maxLen: 50,
|
||||
want: "+49 (0) 123-456.789",
|
||||
},
|
||||
{
|
||||
name: "empty input returns empty",
|
||||
input: "",
|
||||
maxLen: 50,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "only digits passes through",
|
||||
input: "1234567890",
|
||||
maxLen: 50,
|
||||
want: "1234567890",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SanitizePhone(tc.input, tc.maxLen)
|
||||
if got != tc.want {
|
||||
t.Errorf("SanitizePhone(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "reserved value all returns empty",
|
||||
input: "all",
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "all with surrounding whitespace returns empty",
|
||||
input: " all ",
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "reserved value new returns empty",
|
||||
input: "new",
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "reserved value id returns empty",
|
||||
input: "id",
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "system admin identifier returns empty",
|
||||
input: string(CtxSystemAdminId),
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "unknown user identifier returns empty",
|
||||
input: string(CtxUnknownUserId),
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "LDAP syncer identifier returns empty",
|
||||
input: string(CtxSystemLdapSyncer),
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "ALL uppercase passes through (case-sensitive)",
|
||||
input: "ALL",
|
||||
maxLen: 256,
|
||||
want: "ALL",
|
||||
},
|
||||
{
|
||||
name: "valid email identifier passes through",
|
||||
input: "alice@example.com",
|
||||
maxLen: 256,
|
||||
want: "alice@example.com",
|
||||
},
|
||||
{
|
||||
name: "normal identifier passes through",
|
||||
input: "alice",
|
||||
maxLen: 256,
|
||||
want: "alice",
|
||||
},
|
||||
{
|
||||
name: "empty input returns empty",
|
||||
input: "",
|
||||
maxLen: 256,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SanitizeIdentifier(tc.input, tc.maxLen)
|
||||
if got != tc.want {
|
||||
t.Errorf("SanitizeIdentifier(%q, %d) = %q; want %q", tc.input, tc.maxLen, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeXSSPayload(t *testing.T) {
|
||||
// XSS payload: null byte removed, angle brackets preserved
|
||||
input := "<script>\x00alert(1)</script>"
|
||||
want := "<script>alert(1)</script>"
|
||||
got := SanitizeString(input, 256)
|
||||
if got != want {
|
||||
t.Errorf("SanitizeString(%q, 256) = %q; want %q", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Property 1: SanitizeString output invariants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Feature: external-identity-sanitization, Property 1: SanitizeString output is free of control characters and bounded in length
|
||||
func TestPropertySanitizeStringOutputInvariants(t *testing.T) {
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
s := rapid.String().Draw(t, "s")
|
||||
maxLen := rapid.IntRange(0, 512).Draw(t, "maxLen")
|
||||
result := SanitizeString(s, maxLen)
|
||||
|
||||
// No control or format runes in result
|
||||
for _, r := range result {
|
||||
if unicode.IsControl(r) || unicode.Is(unicode.Cf, r) {
|
||||
t.Fatalf("result %q contains unsafe character %U", result, r)
|
||||
}
|
||||
}
|
||||
|
||||
if !utf8.ValidString(result) {
|
||||
t.Fatalf("result %q is not valid UTF-8", result)
|
||||
}
|
||||
|
||||
// No leading or trailing whitespace
|
||||
if result != strings.TrimSpace(result) {
|
||||
t.Fatalf("result %q has leading or trailing whitespace", result)
|
||||
}
|
||||
|
||||
// Rune count <= maxLen
|
||||
runeCount := utf8.RuneCountInString(result)
|
||||
if runeCount > maxLen {
|
||||
t.Fatalf("result %q has %d runes, exceeds maxLen %d", result, runeCount, maxLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Property 2: SanitizeString is idempotent
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Feature: external-identity-sanitization, Property 2: SanitizeString is idempotent
|
||||
func TestPropertySanitizeStringIdempotent(t *testing.T) {
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
s := rapid.String().Draw(t, "s")
|
||||
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||
|
||||
once := SanitizeString(s, maxLen)
|
||||
twice := SanitizeString(once, maxLen)
|
||||
|
||||
if once != twice {
|
||||
t.Fatalf("SanitizeString is not idempotent: once=%q, twice=%q (input=%q, maxLen=%d)",
|
||||
once, twice, s, maxLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Property 3: SanitizeEmail rejection rules
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Feature: external-identity-sanitization, Property 3: SanitizeEmail rejects strings without "@" or containing CR/LF
|
||||
func TestPropertySanitizeEmailRejectionRules(t *testing.T) {
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
s := rapid.String().Draw(t, "s")
|
||||
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||
result := SanitizeEmail(s, maxLen)
|
||||
|
||||
sanitized := SanitizeString(s, maxLen)
|
||||
addr, parseErr := mail.ParseAddress(sanitized)
|
||||
reject := strings.ContainsAny(s, "\r\n") ||
|
||||
sanitized == "" ||
|
||||
strings.Count(sanitized, "@") != 1 ||
|
||||
parseErr != nil ||
|
||||
addr.Name != "" ||
|
||||
addr.Address != sanitized
|
||||
if reject {
|
||||
if result != "" {
|
||||
t.Fatalf("SanitizeEmail(%q, %d) = %q; expected empty string (contains CR/LF or no @)",
|
||||
s, maxLen, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Property 4: SanitizePhone allowed character set
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// isAllowedPhoneCharTest mirrors the internal isAllowedPhoneRune logic for test assertions.
|
||||
func isAllowedPhoneCharTest(r rune) bool {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
return true
|
||||
case r == '+', r == '-', r == '(', r == ')', r == ' ', r == '.':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Feature: external-identity-sanitization, Property 4: SanitizePhone output contains only allowed characters
|
||||
func TestPropertySanitizePhoneAllowedChars(t *testing.T) {
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
s := rapid.String().Draw(t, "s")
|
||||
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||
result := SanitizePhone(s, maxLen)
|
||||
|
||||
for _, r := range result {
|
||||
if !isAllowedPhoneCharTest(r) {
|
||||
t.Fatalf("SanitizePhone(%q, %d) = %q; contains disallowed rune %U (%c)",
|
||||
s, maxLen, result, r, r)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Property 5: SanitizeIdentifier rejects reserved identifiers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Feature: external-identity-sanitization, Property 5: SanitizeIdentifier rejects reserved values
|
||||
func TestPropertySanitizeIdentifierRejectsReservedValues(t *testing.T) {
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
s := rapid.String().Draw(t, "s")
|
||||
maxLen := rapid.IntRange(1, 512).Draw(t, "maxLen")
|
||||
result := SanitizeIdentifier(s, maxLen)
|
||||
sanitized := SanitizeString(s, maxLen)
|
||||
|
||||
_, reserved := reservedUserIdentifiers[sanitized]
|
||||
if reserved {
|
||||
if result != "" {
|
||||
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected empty string when sanitized is reserved",
|
||||
s, maxLen, result)
|
||||
}
|
||||
} else {
|
||||
if result != sanitized {
|
||||
t.Fatalf("SanitizeIdentifier(%q, %d) = %q; expected %q (== SanitizeString result)",
|
||||
s, maxLen, result, sanitized)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -282,6 +282,32 @@ func (u *User) CreateDefaultPeers() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SanitizeExternalData sanitizes user profile fields received from an external identity provider.
|
||||
// Returns ErrInvalidData if the identifier becomes empty after sanitization.
|
||||
func (u *User) SanitizeExternalData(providerType, providerName string) error {
|
||||
identifier := string(u.Identifier)
|
||||
LogSanitizeChange(providerType, providerName, "identifier", identifier,
|
||||
func() string { return SanitizeIdentifier(identifier, 256) }, &identifier)
|
||||
u.Identifier = UserIdentifier(identifier)
|
||||
|
||||
LogSanitizeChange(providerType, providerName, "email", u.Email,
|
||||
func() string { return SanitizeEmail(u.Email, 254) }, &u.Email)
|
||||
LogSanitizeChange(providerType, providerName, "firstname", u.Firstname,
|
||||
func() string { return SanitizeString(u.Firstname, 128) }, &u.Firstname)
|
||||
LogSanitizeChange(providerType, providerName, "lastname", u.Lastname,
|
||||
func() string { return SanitizeString(u.Lastname, 128) }, &u.Lastname)
|
||||
LogSanitizeChange(providerType, providerName, "phone", u.Phone,
|
||||
func() string { return SanitizePhone(u.Phone, 50) }, &u.Phone)
|
||||
LogSanitizeChange(providerType, providerName, "department", u.Department,
|
||||
func() string { return SanitizeString(u.Department, 128) }, &u.Department)
|
||||
|
||||
if u.Identifier == "" {
|
||||
return fmt.Errorf("empty user identifier: %w", ErrInvalidData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// region webauthn
|
||||
|
||||
func (u *User) WebAuthnID() []byte {
|
||||
|
||||
64
internal/domain/user_sanitize_test.go
Normal file
64
internal/domain/user_sanitize_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/testutil"
|
||||
)
|
||||
|
||||
func TestUser_SanitizeExternalData_NullByteInFirstname(t *testing.T) {
|
||||
u := &User{
|
||||
Identifier: "alice",
|
||||
Email: "alice@example.com",
|
||||
Firstname: "Ali\x00ce",
|
||||
Lastname: "Smith",
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Alice", u.Firstname)
|
||||
assert.Equal(t, 1, testutil.CountWarnEntries(records))
|
||||
|
||||
_, found := testutil.FindWarnWithField(records, "firstname")
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestUser_SanitizeExternalData_IdentifierAll(t *testing.T) {
|
||||
u := &User{
|
||||
Identifier: "all",
|
||||
Email: "all@example.com",
|
||||
Firstname: "Alice",
|
||||
Lastname: "Smith",
|
||||
}
|
||||
|
||||
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrInvalidData))
|
||||
}
|
||||
|
||||
func TestUser_SanitizeExternalData_AllFieldsClean(t *testing.T) {
|
||||
u := &User{
|
||||
Identifier: "alice",
|
||||
Email: "alice@example.com",
|
||||
Firstname: "Alice",
|
||||
Lastname: "Smith",
|
||||
Phone: "+1 555-1234",
|
||||
Department: "Engineering",
|
||||
}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
err := u.SanitizeExternalData("ldap", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, UserIdentifier("alice"), u.Identifier)
|
||||
assert.Equal(t, 0, testutil.CountWarnEntries(records))
|
||||
}
|
||||
Reference in New Issue
Block a user