mirror of
https://github.com/h44z/wg-portal.git
synced 2026-05-28 08:56:17 +00:00
feat: sanitize external identity provider user data (#681)
* feat: sanitize external user data * remove config option to disable Sanitization: sanitize_external_user_data * cleanup --------- Co-authored-by: Christoph Haas <christoph.h@sprinternet.at>
This commit is contained in:
@@ -149,5 +149,9 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica
|
||||
AdminInfoAvailable: adminInfoAvailable,
|
||||
}
|
||||
|
||||
if err := userInfo.Sanitize("ldap", l.cfg.ProviderName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
98
internal/app/auth/auth_ldap_sanitize_test.go
Normal file
98
internal/app/auth/auth_ldap_sanitize_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/testutil"
|
||||
)
|
||||
|
||||
// makeLdapAuthenticator creates a minimal LdapAuthenticator for testing ParseUserInfo.
|
||||
func makeLdapAuthenticator() *LdapAuthenticator {
|
||||
return &LdapAuthenticator{
|
||||
cfg: &config.LdapProvider{
|
||||
ProviderName: "test-ldap",
|
||||
FieldMap: config.LdapFields{
|
||||
BaseFields: config.BaseFields{
|
||||
UserIdentifier: "uid",
|
||||
Email: "mail",
|
||||
Firstname: "givenName",
|
||||
Lastname: "sn",
|
||||
Phone: "telephoneNumber",
|
||||
Department: "department",
|
||||
},
|
||||
GroupMembership: "", // no group membership check
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeRawLdapMap builds a minimal raw LDAP attribute map for ParseUserInfo.
|
||||
func makeRawLdapMap(uid, mail, givenName, sn, phone, department string) map[string]any {
|
||||
return map[string]any{
|
||||
"uid": uid,
|
||||
"mail": mail,
|
||||
"givenName": givenName,
|
||||
"sn": sn,
|
||||
"telephoneNumber": phone,
|
||||
"department": department,
|
||||
}
|
||||
}
|
||||
|
||||
// Test: firstname contains \x00 → output firstname has no null byte,
|
||||
// one WARN log entry with field: "firstname".
|
||||
func TestLdapParseUserInfo_NullByteInFirstname(t *testing.T) {
|
||||
auth := makeLdapAuthenticator()
|
||||
raw := makeRawLdapMap("alice", "alice@example.com", "Ali\x00ce", "Smith", "", "")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
info, err := auth.ParseUserInfo(raw)
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, info.Firstname, "\x00", "firstname should have null byte removed")
|
||||
assert.Equal(t, "Alice", info.Firstname)
|
||||
|
||||
warnCount := testutil.CountWarnEntries(records)
|
||||
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||
|
||||
rec, found := testutil.FindWarnWithField(records, "firstname")
|
||||
assert.True(t, found, "expected WARN log entry with field=firstname")
|
||||
if found {
|
||||
assert.Equal(t, "WARN", rec["level"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test: all fields clean → no WARN log entries emitted.
|
||||
func TestLdapParseUserInfo_AllFieldsClean(t *testing.T) {
|
||||
auth := makeLdapAuthenticator()
|
||||
raw := makeRawLdapMap("alice", "alice@example.com", "Alice", "Smith", "+1 555-1234", "Engineering")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
info, err := auth.ParseUserInfo(raw)
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, domain.UserIdentifier("alice"), info.Identifier)
|
||||
|
||||
warnCount := testutil.CountWarnEntries(records)
|
||||
assert.Equal(t, 0, warnCount, "expected no WARN log entries when all fields are clean")
|
||||
}
|
||||
|
||||
// Test: identifier is "all" → returns ErrInvalidData.
|
||||
func TestLdapParseUserInfo_IdentifierAll(t *testing.T) {
|
||||
auth := makeLdapAuthenticator()
|
||||
raw := makeRawLdapMap("all", "all@example.com", "Alice", "Smith", "", "")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
_, err := auth.ParseUserInfo(raw)
|
||||
_ = restore()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
|
||||
}
|
||||
@@ -155,5 +155,5 @@ func (p PlainOauthAuthenticator) GetUserInfo(
|
||||
|
||||
// ParseUserInfo parses the user information from the raw data.
|
||||
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
|
||||
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw, "oauth", p.name)
|
||||
}
|
||||
|
||||
@@ -194,5 +194,5 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
|
||||
|
||||
// ParseUserInfo parses the user info.
|
||||
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
|
||||
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw, "oidc", o.name)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ func parseOauthUserInfo(
|
||||
mapping config.OauthFields,
|
||||
adminMapping *config.OauthAdminMapping,
|
||||
raw map[string]any,
|
||||
providerType string,
|
||||
providerName string,
|
||||
) (*domain.AuthenticatorUserInfo, error) {
|
||||
var isAdmin bool
|
||||
var adminInfoAvailable bool
|
||||
@@ -27,18 +29,6 @@ func parseOauthUserInfo(
|
||||
}
|
||||
}
|
||||
|
||||
// next try to parse the user's groups
|
||||
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
|
||||
adminInfoAvailable = true
|
||||
re := adminMapping.GetAdminGroupRegex()
|
||||
for _, group := range userGroups {
|
||||
if re.MatchString(strings.TrimSpace(group)) {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userInfo := &domain.AuthenticatorUserInfo{
|
||||
Identifier: domain.UserIdentifier(internal.MapDefaultString(raw, mapping.UserIdentifier, "")),
|
||||
Email: internal.MapDefaultString(raw, mapping.Email, ""),
|
||||
@@ -51,6 +41,24 @@ func parseOauthUserInfo(
|
||||
AdminInfoAvailable: adminInfoAvailable,
|
||||
}
|
||||
|
||||
if err := userInfo.Sanitize(providerType, providerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check admin group match after sanitization
|
||||
if !isAdmin && mapping.UserGroups != "" && adminMapping.AdminGroupRegex != "" {
|
||||
adminInfoAvailable = true
|
||||
re := adminMapping.GetAdminGroupRegex()
|
||||
for _, group := range userInfo.UserGroups {
|
||||
if re.MatchString(group) {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
userInfo.IsAdmin = isAdmin
|
||||
userInfo.AdminInfoAvailable = adminInfoAvailable
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
|
||||
148
internal/app/auth/oauth_common_sanitize_test.go
Normal file
148
internal/app/auth/oauth_common_sanitize_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/testutil"
|
||||
)
|
||||
|
||||
// makeOauthFieldMapping returns a minimal OauthFields mapping for testing.
|
||||
func makeOauthFieldMapping() config.OauthFields {
|
||||
return config.OauthFields{
|
||||
BaseFields: config.BaseFields{
|
||||
UserIdentifier: "sub",
|
||||
Email: "email",
|
||||
Firstname: "given_name",
|
||||
Lastname: "family_name",
|
||||
Phone: "phone",
|
||||
Department: "department",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeOauthRaw builds a minimal raw OAuth user info map.
|
||||
func makeOauthRaw(sub, email, givenName, familyName, phone, department string) map[string]any {
|
||||
return map[string]any{
|
||||
"sub": sub,
|
||||
"email": email,
|
||||
"given_name": givenName,
|
||||
"family_name": familyName,
|
||||
"phone": phone,
|
||||
"department": department,
|
||||
}
|
||||
}
|
||||
|
||||
// Test: email containing \r\n → output email is "",
|
||||
// one WARN log entry with field: "email" and cleared indication.
|
||||
func TestParseOauthUserInfo_CRLFInEmail(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
adminMapping := &config.OauthAdminMapping{}
|
||||
raw := makeOauthRaw("user123", "user\r\n@example.com", "Alice", "Smith", "", "")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", info.Email, "email should be cleared when it contains CR/LF")
|
||||
|
||||
warnCount := testutil.CountWarnEntries(records)
|
||||
assert.Equal(t, 1, warnCount, "expected exactly one WARN log entry")
|
||||
|
||||
rec, found := testutil.FindWarnWithField(records, "email")
|
||||
assert.True(t, found, "expected WARN log entry with field=email")
|
||||
if found {
|
||||
msg, _ := rec["msg"].(string)
|
||||
assert.Contains(t, msg, "cleared", "expected 'cleared' in log message when email is cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// Test: two fields modified (email cleared, firstname truncated) →
|
||||
// two separate WARN log entries.
|
||||
func TestParseOauthUserInfo_TwoFieldsModified(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
adminMapping := &config.OauthAdminMapping{}
|
||||
|
||||
longFirstname := strings.Repeat("A", 200)
|
||||
raw := makeOauthRaw("user123", "bad\r\nemail@example.com", longFirstname, "Smith", "", "")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", info.Email, "email should be cleared")
|
||||
assert.Equal(t, 128, len([]rune(info.Firstname)), "firstname should be truncated to 128 runes")
|
||||
|
||||
warnCount := testutil.CountWarnEntries(records)
|
||||
assert.Equal(t, 2, warnCount, "expected exactly two WARN log entries (one per modified field)")
|
||||
|
||||
_, emailFound := testutil.FindWarnWithField(records, "email")
|
||||
assert.True(t, emailFound, "expected WARN log entry with field=email")
|
||||
|
||||
_, firstnameFound := testutil.FindWarnWithField(records, "firstname")
|
||||
assert.True(t, firstnameFound, "expected WARN log entry with field=firstname")
|
||||
}
|
||||
|
||||
// Test: identifier "all" → returns ErrInvalidData.
|
||||
func TestParseOauthUserInfo_IdentifierAll(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
adminMapping := &config.OauthAdminMapping{}
|
||||
raw := makeOauthRaw("all", "all@example.com", "Alice", "Smith", "", "")
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
_, err := parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||
_ = restore()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, domain.ErrInvalidData), "expected ErrInvalidData when identifier is 'all'")
|
||||
}
|
||||
|
||||
func TestParseOauthUserInfo_DropsModifiedGroupBeforeAdminMatch(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
mapping.UserGroups = "groups"
|
||||
adminMapping := &config.OauthAdminMapping{
|
||||
AdminGroupRegex: "^wgportal-admins$",
|
||||
}
|
||||
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
|
||||
raw["groups"] = []any{"wgportal-\u200badmins"}
|
||||
|
||||
restore := testutil.CaptureWarnLogs(t)
|
||||
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
|
||||
records := restore()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
assert.False(t, info.IsAdmin, "sanitization must not repair a modified group into an admin match")
|
||||
assert.Empty(t, info.UserGroups)
|
||||
|
||||
rec, found := testutil.FindWarnWithField(records, "user_group")
|
||||
assert.True(t, found, "expected WARN log entry with field=user_group")
|
||||
if found {
|
||||
assert.Equal(t, "oidc", rec["provider_type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOauthUserInfo_AllowsWhitespaceOnlyGroupTrim(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
mapping.UserGroups = "groups"
|
||||
adminMapping := &config.OauthAdminMapping{
|
||||
AdminGroupRegex: "^wgportal-admins$",
|
||||
}
|
||||
raw := makeOauthRaw("user123", "user@example.com", "Alice", "Smith", "", "")
|
||||
raw["groups"] = []any{" wgportal-admins "}
|
||||
|
||||
info, err := parseOauthUserInfo(mapping, adminMapping, raw, "oidc", "test-provider")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
assert.True(t, info.IsAdmin)
|
||||
assert.Equal(t, []string{"wgportal-admins"}, info.UserGroups)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func Test_parseOauthUserInfo_no_admin(t *testing.T) {
|
||||
})
|
||||
adminMapping := &config.OauthAdminMapping{}
|
||||
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, info.IsAdmin)
|
||||
assert.Equal(t, info.Firstname, "Test User")
|
||||
@@ -90,7 +90,7 @@ func Test_parseOauthUserInfo_admin_group(t *testing.T) {
|
||||
AdminGroupRegex: "^wgportal-admins@mydomain.net$",
|
||||
}
|
||||
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsAdmin)
|
||||
assert.Equal(t, info.Firstname, "Test User")
|
||||
@@ -132,7 +132,7 @@ func Test_parseOauthUserInfo_admin_value(t *testing.T) {
|
||||
})
|
||||
adminMapping := &config.OauthAdminMapping{} // test with default regex
|
||||
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsAdmin)
|
||||
assert.Equal(t, info.Firstname, "Test User")
|
||||
@@ -175,7 +175,7 @@ func Test_parseOauthUserInfo_admin_value_custom(t *testing.T) {
|
||||
AdminValueRegex: "^1$",
|
||||
}
|
||||
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo)
|
||||
info, err := parseOauthUserInfo(fieldMapping, adminMapping, userInfo, "oauth", "test-provider")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.IsAdmin)
|
||||
assert.Equal(t, info.Firstname, "Test User")
|
||||
|
||||
90
internal/app/auth/sanitize_log_test.go
Normal file
90
internal/app/auth/sanitize_log_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"pgregory.net/rapid"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/testutil"
|
||||
)
|
||||
|
||||
// captureWarnLogsInline redirects the default slog logger to a buffer, calls fn,
|
||||
// restores the original logger, and returns the captured log records.
|
||||
func captureWarnLogsInline(fn func()) []map[string]any {
|
||||
original := slog.Default()
|
||||
var buf bytes.Buffer
|
||||
handler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
|
||||
fn()
|
||||
|
||||
slog.SetDefault(original)
|
||||
|
||||
var records []map[string]any
|
||||
decoder := json.NewDecoder(&buf)
|
||||
for decoder.More() {
|
||||
var rec map[string]any
|
||||
if err := decoder.Decode(&rec); err == nil {
|
||||
records = append(records, rec)
|
||||
}
|
||||
}
|
||||
return records
|
||||
}
|
||||
|
||||
// Property 7: Sanitization change logging completeness
|
||||
func TestPropertySanitizationChangeLoggingCompleteness(t *testing.T) {
|
||||
mapping := makeOauthFieldMapping()
|
||||
adminMapping := &config.OauthAdminMapping{}
|
||||
|
||||
rapid.Check(t, func(t *rapid.T) {
|
||||
sub := rapid.StringMatching(`[a-zA-Z0-9_@.-]{1,50}`).Draw(t, "sub")
|
||||
email := rapid.String().Draw(t, "email")
|
||||
firstname := rapid.String().Draw(t, "firstname")
|
||||
lastname := rapid.String().Draw(t, "lastname")
|
||||
phone := rapid.String().Draw(t, "phone")
|
||||
department := rapid.String().Draw(t, "department")
|
||||
|
||||
if sub == "" {
|
||||
sub = "testuser"
|
||||
}
|
||||
|
||||
raw := makeOauthRaw(sub, email, firstname, lastname, phone, department)
|
||||
|
||||
// Count how many fields will actually change after sanitization
|
||||
expectedChanges := 0
|
||||
if domain.SanitizeIdentifier(sub, 256) != sub {
|
||||
expectedChanges++
|
||||
}
|
||||
if domain.SanitizeEmail(email, 254) != email {
|
||||
expectedChanges++
|
||||
}
|
||||
if domain.SanitizeString(firstname, 128) != firstname {
|
||||
expectedChanges++
|
||||
}
|
||||
if domain.SanitizeString(lastname, 128) != lastname {
|
||||
expectedChanges++
|
||||
}
|
||||
if domain.SanitizePhone(phone, 50) != phone {
|
||||
expectedChanges++
|
||||
}
|
||||
if domain.SanitizeString(department, 128) != department {
|
||||
expectedChanges++
|
||||
}
|
||||
|
||||
var records []map[string]any
|
||||
records = captureWarnLogsInline(func() {
|
||||
_, _ = parseOauthUserInfo(mapping, adminMapping, raw, "oauth", "test-provider")
|
||||
})
|
||||
|
||||
actualWarnCount := testutil.CountWarnEntries(records)
|
||||
require.Equal(t, expectedChanges, actualWarnCount,
|
||||
"number of WARN log entries (%d) must equal number of fields changed by sanitization (%d)",
|
||||
actualWarnCount, expectedChanges)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user