wg-portal/internal/app/gorm_encryption.go

202 lines
5.2 KiB
Go

package app
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"reflect"
"strings"
"gorm.io/gorm/schema"
"github.com/h44z/wg-portal/internal/domain"
)
// GormEncryptedStringSerializer is a GORM serializer that transparently
// encrypts string columns with AES-256 on write and decrypts them on read, so
// sensitive values are not stored in plain text.
//
// Encrypted values are stored with a marker prefix. When no key phrase is
// configured, values are passed through unchanged in both directions. With
// encryption enabled, Value accepts only string and domain.PreSharedKey values
// and returns an error for any other type.
type GormEncryptedStringSerializer struct {
	useEncryption     bool // false when the key phrase is empty
	keyPhrase, prefix string
}

// NewGormEncryptedStringSerializer returns a serializer that uses keyPhrase as
// the AES key material. An empty keyPhrase disables encryption.
//
// The serializer must be registered with GORM before use:
//
//	schema.RegisterSerializer("encstr", gormEncryptedStringSerializerInstance)
//
// Model fields can then opt in with the serializer tag:
//
//	EncryptedField string `gorm:"serializer:encstr"`
func NewGormEncryptedStringSerializer(keyPhrase string) GormEncryptedStringSerializer {
	var s GormEncryptedStringSerializer
	s.keyPhrase = keyPhrase
	s.useEncryption = len(keyPhrase) > 0
	s.prefix = "WG_ENC_" // marks values written in encrypted form
	return s
}
// Scan implements the GORM serializer interface. It converts the raw column
// value (string or []byte; a NULL column becomes "") to a string and assigns
// it to the field of dst.
//
// Decryption happens only when a key phrase is configured AND the stored value
// starts with the serializer prefix. Any other value is assigned exactly as
// stored; presumably this keeps rows written before encryption was enabled
// readable.
//
// It returns an error for unsupported column types and when decrypting a
// prefixed value fails.
func (s GormEncryptedStringSerializer) Scan(
	ctx context.Context,
	field *schema.Field,
	dst reflect.Value,
	dbValue any,
) (err error) {
	var stored string
	switch raw := dbValue.(type) {
	case nil:
		// NULL column: keep the empty string.
	case string:
		stored = raw
	case []byte:
		stored = string(raw)
	default:
		return fmt.Errorf("unsupported type %T for encrypted field %s", dbValue, field.Name)
	}

	plaintext := stored
	if s.useEncryption && strings.HasPrefix(stored, s.prefix) {
		decrypted, decErr := DecryptAES256(strings.TrimPrefix(stored, s.prefix), s.keyPhrase)
		if decErr != nil {
			return fmt.Errorf("failed to decrypt value for field %s: %w", field.Name, decErr)
		}
		plaintext = decrypted
	}

	// The field is only touched once the value is known to be good.
	field.ReflectValueOf(ctx, dst).SetString(plaintext)
	return nil
}
// Value implements the GORM serializer interface. It encrypts the field value
// before it is written to the database.
//
// Behavior:
//   - nil is stored as NULL (nil is returned).
//   - Without a configured key phrase, the value is returned unchanged.
//   - With encryption enabled, string and domain.PreSharedKey values are
//     encrypted with EncryptAES256 and returned with the serializer prefix.
//     Empty values are stored as "" without encryption.
//   - Any other type yields an error when encryption is enabled.
func (s GormEncryptedStringSerializer) Value(
	_ context.Context,
	_ *schema.Field,
	_ reflect.Value,
	fieldValue any,
) (any, error) {
	if fieldValue == nil {
		return nil, nil
	}
	if !s.useEncryption {
		return fieldValue, nil // keep the original value
	}

	// Reduce every supported type to its plain string form, then encrypt once.
	var plaintext string
	switch v := fieldValue.(type) {
	case string:
		plaintext = v
	case domain.PreSharedKey:
		plaintext = string(v)
	default:
		return nil, fmt.Errorf("encryption only supports string values, got %T", fieldValue)
	}

	if plaintext == "" {
		return "", nil // empty string, no need to encrypt
	}
	encrypted, err := EncryptAES256(plaintext, s.keyPhrase)
	if err != nil {
		return nil, fmt.Errorf("encrypting %T value: %w", fieldValue, err)
	}
	return s.prefix + encrypted, nil
}
// EncryptAES256 encrypts plaintext with key using AES-256 in CBC mode with
// PKCS#7 padding. It returns the ciphertext encoded with standard base64.
// trimEncKey normalizes key to exactly 32 bytes, so AES-256 is always used.
//
// It returns an error if plaintext or key is empty.
//
// NOTE(security): the IV is the first 16 bytes of the normalized key, so the
// output is deterministic. Equal plaintexts under the same key give equal
// ciphertexts, and the ciphertext is not authenticated. DecryptAES256 relies
// on the same derivation, so changing it would make previously stored values
// undecryptable. Moving to a random IV plus a MAC (or AES-GCM) requires a new,
// versioned storage format.
func EncryptAES256(plaintext, key string) (string, error) {
	if len(plaintext) == 0 {
		return "", fmt.Errorf("plaintext must not be empty")
	}
	if len(key) == 0 {
		return "", fmt.Errorf("key must not be empty")
	}

	keyBytes := []byte(trimEncKey(key))
	block, err := aes.NewCipher(keyBytes)
	if err != nil {
		return "", err
	}
	iv := keyBytes[:aes.BlockSize] // key-derived IV, see NOTE above

	padded := pkcs7Padding([]byte(plaintext), aes.BlockSize)
	ciphertext := make([]byte, len(padded))
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, padded)
	return base64.StdEncoding.EncodeToString(ciphertext), nil
}

// DecryptAES256 reverses EncryptAES256. It base64-decodes encrypted, decrypts
// it with AES-256-CBC using the same key-derived IV, and strips the PKCS#7
// padding.
//
// It returns an error instead of panicking in each of these cases: the input
// or key is empty, the input is not valid base64, the decoded ciphertext is
// empty or not a multiple of the AES block size, or the padding is malformed.
// Malformed padding is the usual symptom of a wrong key or corrupted data.
// Because the ciphertext is unauthenticated, a wrong key can still
// occasionally yield well-formed padding and therefore garbage output.
func DecryptAES256(encrypted, key string) (string, error) {
	if len(encrypted) == 0 {
		return "", fmt.Errorf("ciphertext must not be empty")
	}
	if len(key) == 0 {
		return "", fmt.Errorf("key must not be empty")
	}

	keyBytes := []byte(trimEncKey(key))
	ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", fmt.Errorf("decoding base64 ciphertext: %w", err)
	}
	// Base64 decoding ignores newlines, so e.g. "\n" decodes to zero bytes.
	if len(ciphertext) == 0 {
		return "", fmt.Errorf("ciphertext must not be empty")
	}
	if len(ciphertext)%aes.BlockSize != 0 {
		return "", fmt.Errorf("invalid ciphertext length, must be a multiple of %d", aes.BlockSize)
	}

	block, err := aes.NewCipher(keyBytes)
	if err != nil {
		return "", err
	}
	cipher.NewCBCDecrypter(block, keyBytes[:aes.BlockSize]).CryptBlocks(ciphertext, ciphertext)

	plain, err := pkcs7UnPadding(ciphertext, aes.BlockSize)
	if err != nil {
		return "", fmt.Errorf("removing padding (wrong key or corrupted data?): %w", err)
	}
	return string(plain), nil
}

// pkcs7Padding appends PKCS#7 padding to data so its length becomes a
// multiple of blockSize. A full block of padding is added when data is
// already aligned. data's backing array may be reused by append, so callers
// should pass a slice they own.
func pkcs7Padding(data []byte, blockSize int) []byte {
	padding := blockSize - len(data)%blockSize
	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
	return append(data, padtext...)
}

// pkcs7UnPadding validates and removes PKCS#7 padding from src. src must have
// a non-zero length that is a multiple of blockSize. The last byte must be a
// pad length in [1, blockSize], and every padding byte must equal that length.
// It returns an error rather than panicking when any of these checks fail.
//
// NOTE(security): the checks are not constant-time. Without a MAC, CBC is
// exposed to padding-oracle attacks if these errors ever reach an attacker.
func pkcs7UnPadding(src []byte, blockSize int) ([]byte, error) {
	length := len(src)
	if blockSize <= 0 || length == 0 || length%blockSize != 0 {
		return nil, fmt.Errorf("padded data length %d is not a positive multiple of block size %d", length, blockSize)
	}
	padLen := int(src[length-1])
	if padLen == 0 || padLen > blockSize {
		return nil, fmt.Errorf("invalid padding length %d", padLen)
	}
	for _, b := range src[length-padLen:] {
		if int(b) != padLen {
			return nil, fmt.Errorf("invalid padding bytes")
		}
	}
	return src[:length-padLen], nil
}

// trimEncKey normalizes key to exactly 32 bytes, the AES-256 key size. Longer
// keys are truncated and shorter keys are right-padded with ASCII '0'.
// Lengths are measured in bytes, not runes.
//
// NOTE(security): this is not a key-derivation function. Low-entropy
// passphrases are not strengthened, and keys that differ only after byte 32
// are equivalent.
func trimEncKey(key string) string {
	const keySize = 32
	switch {
	case len(key) > keySize:
		return key[:keySize]
	case len(key) < keySize:
		return key + strings.Repeat("0", keySize-len(key))
	default:
		return key
	}
}