chore: replace gin with standard lib net/http

This commit is contained in:
Christoph Haas
2025-03-09 21:16:42 +01:00
parent 7473132932
commit 0206952182
58 changed files with 5302 additions and 1390 deletions

View File

@@ -0,0 +1,214 @@
package cors
import (
"net/http"
"slices"
"strconv"
"strings"
)
// Middleware is a type that creates a new CORS middleware. The CORS middleware
// adds Cross-Origin Resource Sharing headers to the response. This middleware should
// be used to allow cross-origin requests to your server.
type Middleware struct {
o options
varyHeaders string // precomputed Vary header
allOrigins bool // all origins are allowed
}
// New returns a new CORS middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
// set vary headers
if m.o.allowPrivateNetworks {
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"
} else {
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers"
}
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
m.allOrigins = true
}
return m
}
// Handler returns the CORS middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle preflight requests and stop the chain as some other
// middleware may not handle OPTIONS requests correctly.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
m.handlePreflight(w, r)
w.WriteHeader(http.StatusNoContent) // always return 204 No Content
return
}
// handle normal CORS requests
m.handleNormal(w, r)
next.ServeHTTP(w, r) // execute the next handler
})
}
// region internal-helpers
// handlePreflight handles preflight requests. If the request was successful, this function will
// write the CORS headers and return. If the request was not successful, this function will
// not add any CORS headers and return - thus the CORS request is considered invalid.
func (m *Middleware) handlePreflight(w http.ResponseWriter, r *http.Request) {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
w.Header().Add("Vary", m.varyHeaders)
// check origin
origin := r.Header.Get("Origin")
if origin == "" {
return // not a valid CORS request
}
if !m.originAllowed(origin) {
return
}
// check method
reqMethod := r.Header.Get("Access-Control-Request-Method")
if !m.methodAllowed(reqMethod) {
return
}
// check headers
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
if !m.headersAllowed(reqHeaders) {
return
}
// set CORS headers for the successful preflight request
if m.allOrigins {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
}
w.Header().Set("Access-Control-Allow-Methods", reqMethod)
if reqHeaders != "" {
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
}
if m.o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if m.o.allowPrivateNetworks && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
w.Header().Set("Access-Control-Allow-Private-Network", "true")
}
if m.o.maxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.o.maxAge))
}
}
// handleNormal handles normal CORS requests. If the request was successful, this function will
// write the CORS headers to the response. If the request was not successful, this function will
// not add any CORS headers to the response. In this case, the CORS request is considered invalid.
func (m *Middleware) handleNormal(w http.ResponseWriter, r *http.Request) {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
w.Header().Add("Vary", "Origin")
// check origin
origin := r.Header.Get("Origin")
if origin == "" {
return // not a valid CORS request
}
if !m.originAllowed(origin) {
return
}
// check method
if !m.methodAllowed(r.Method) {
return
}
// set CORS headers for the successful CORS request
if m.allOrigins {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
}
if len(m.o.exposedHeaders) > 0 {
w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.o.exposedHeaders, ", "))
}
if m.o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
func (m *Middleware) originAllowed(origin string) bool {
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
return true // everything is allowed
}
// check simple origins
if slices.Contains(m.o.allowedOrigins, origin) {
return true
}
// check wildcard origins
for _, allowedOrigin := range m.o.allowedOriginPatterns {
if allowedOrigin.match(origin) {
return true
}
}
return false
}
func (m *Middleware) methodAllowed(method string) bool {
if method == http.MethodOptions {
return true // preflight request is always allowed
}
if len(m.o.allowedMethods) == 1 && m.o.allowedMethods[0] == "*" {
return true // everything is allowed
}
if slices.Contains(m.o.allowedMethods, method) {
return true
}
return false
}
func (m *Middleware) headersAllowed(headers string) bool {
if headers == "" {
return true // no headers are requested
}
if len(m.o.allowedHeaders) == 0 {
return false // no headers are allowed
}
if _, ok := m.o.allowedHeaders["*"]; ok {
return true // everything is allowed
}
// split headers by comma (according to definition, the headers are sorted and in lowercase)
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers
for header := range strings.SplitSeq(headers, ",") {
if _, ok := m.o.allowedHeaders[strings.TrimSpace(header)]; !ok {
return false
}
}
return true
}
// endregion internal-helpers

View File

@@ -0,0 +1,101 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestMiddleware_New(t *testing.T) {
m := New(WithAllowedOrigins("*"))
if len(m.varyHeaders) == 0 {
t.Errorf("expected vary headers to be populated, got %v", m.varyHeaders)
}
if !m.allOrigins {
t.Errorf("expected allOrigins to be true, got %v", m.allOrigins)
}
}
func TestMiddleware_Handler_normal(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Result().StatusCode != http.StatusOK {
t.Errorf("expected status code 200, got %d", w.Result().StatusCode)
}
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestMiddleware_Handler_preflight(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "http://example.com", nil)
req.Header.Set("Origin", "http://example.com")
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Result().StatusCode != http.StatusNoContent {
t.Errorf("expected status code 204, got %d", w.Result().StatusCode)
}
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestMiddleware_originAllowed(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
if !m.originAllowed("http://example.com") {
t.Errorf("expected origin 'http://example.com' to be allowed")
}
if m.originAllowed("http://notallowed.com") {
t.Errorf("expected origin 'http://notallowed.com' to be not allowed")
}
}
func TestMiddleware_methodAllowed(t *testing.T) {
m := New(WithAllowedMethods(http.MethodGet, http.MethodPost))
if !m.methodAllowed(http.MethodGet) {
t.Errorf("expected method 'GET' to be allowed")
}
if m.methodAllowed(http.MethodDelete) {
t.Errorf("expected method 'DELETE' to be not allowed")
}
}
func TestMiddleware_headersAllowed(t *testing.T) {
m := New(WithAllowedHeaders("Content-Type", "Authorization"))
if !m.headersAllowed("content-type, authorization") {
t.Errorf("expected headers 'Content-Type, Authorization' to be allowed")
}
if m.headersAllowed("x-custom-header") {
t.Errorf("expected header 'X-Custom-Header' to be not allowed")
}
}

View File

@@ -0,0 +1,133 @@
package cors
import (
"net/http"
"strings"
)
type void struct{}
// options is a struct that contains options for the CORS middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
allowedOrigins []string // origins without wildcards
allowedOriginPatterns []wildcard // origins with wildcards
allowedMethods []string
allowedHeaders map[string]void
exposedHeaders []string // these are in addition to the CORS-safelisted response headers
allowCredentials bool
allowPrivateNetworks bool
maxAge int
}
// Option is a type that is used to set options for the CORS middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithAllowedOrigins sets the allowed origins for the CORS middleware.
// If the special "*" value is present in the list, all origins will be allowed.
// An origin may contain a wildcard (*) to replace 0 or more characters
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
// Only one wildcard can be used per origin.
// By default, all origins are allowed (*).
func WithAllowedOrigins(origins ...string) Option {
return func(o *options) {
o.allowedOrigins = nil
o.allowedOriginPatterns = nil
for _, origin := range origins {
if len(origin) > 1 && strings.Contains(origin, "*") {
o.allowedOriginPatterns = append(
o.allowedOriginPatterns,
newWildcard(origin),
)
} else {
o.allowedOrigins = append(o.allowedOrigins, origin)
}
}
}
}
// WithAllowedMethods sets the allowed methods for the CORS middleware.
// By default, all methods are allowed (*).
func WithAllowedMethods(methods ...string) Option {
return func(o *options) {
o.allowedMethods = methods
}
}
// WithAllowedHeaders sets the allowed headers for the CORS middleware.
// By default, all headers are allowed (*).
func WithAllowedHeaders(headers ...string) Option {
return func(o *options) {
o.allowedHeaders = make(map[string]void)
for _, header := range headers {
// allowed headers are always checked in lowercase
o.allowedHeaders[strings.ToLower(header)] = void{}
}
}
}
// WithExposedHeaders sets the exposed headers for the CORS middleware.
// By default, no headers are exposed.
func WithExposedHeaders(headers ...string) Option {
return func(o *options) {
o.exposedHeaders = nil
for _, header := range headers {
o.exposedHeaders = append(o.exposedHeaders, http.CanonicalHeaderKey(header))
}
}
}
// WithAllowCredentials sets the allow credentials option for the CORS middleware.
// This setting indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
// By default, credentials are not allowed.
func WithAllowCredentials(allow bool) Option {
return func(o *options) {
o.allowCredentials = allow
}
}
// WithAllowPrivateNetworks sets the allow private networks option for the CORS middleware.
// This setting indicates whether to accept cross-origin requests over a private network.
func WithAllowPrivateNetworks(allow bool) Option {
return func(o *options) {
o.allowPrivateNetworks = allow
}
}
// WithMaxAge sets the max age (in seconds) for the CORS middleware.
// The maximum age indicates how long (in seconds) the results of a preflight request
// can be cached. A value of 0 means that no Access-Control-Max-Age header is sent back,
// resulting in browsers using their default value (5s by spec).
// If you need to force a 0 max-age, set it to a negative value (ie: -1).
// By default, the max age is 7200 seconds.
func WithMaxAge(age int) Option {
return func(o *options) {
o.maxAge = age
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
allowedOrigins: []string{"*"},
allowedMethods: []string{
http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete,
},
allowedHeaders: map[string]void{"*": {}},
exposedHeaders: nil,
allowCredentials: false,
allowPrivateNetworks: false,
maxAge: 0,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@@ -0,0 +1,96 @@
package cors
import (
"maps"
"net/http"
"slices"
"testing"
)
func TestWithAllowedOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
wantNormal []string
wantWildcard []wildcard
}{
{
name: "No origins",
origins: []string{},
wantNormal: nil,
wantWildcard: nil,
},
{
name: "Single origin",
origins: []string{"http://example.com"},
wantNormal: []string{"http://example.com"},
wantWildcard: nil,
},
{
name: "Wildcard origin",
origins: []string{"http://*.example.com"},
wantNormal: nil,
wantWildcard: []wildcard{newWildcard("http://*.example.com")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := newOptions(WithAllowedOrigins(tt.origins...))
if !slices.Equal(o.allowedOrigins, tt.wantNormal) {
t.Errorf("got %v, want %v", o, tt.wantNormal)
}
if !slices.Equal(o.allowedOriginPatterns, tt.wantWildcard) {
t.Errorf("got %v, want %v", o, tt.wantWildcard)
}
})
}
}
func TestWithAllowedMethods(t *testing.T) {
methods := []string{http.MethodGet, http.MethodPost}
o := newOptions(WithAllowedMethods(methods...))
if !slices.Equal(o.allowedMethods, methods) {
t.Errorf("got %v, want %v", o.allowedMethods, methods)
}
}
func TestWithAllowedHeaders(t *testing.T) {
headers := []string{"Content-Type", "Authorization"}
o := newOptions(WithAllowedHeaders(headers...))
expectedHeaders := map[string]void{"content-type": {}, "authorization": {}}
if !maps.Equal(o.allowedHeaders, expectedHeaders) {
t.Errorf("got %v, want %v", o.allowedHeaders, expectedHeaders)
}
}
func TestWithExposedHeaders(t *testing.T) {
headers := []string{"X-Custom-Header"}
o := newOptions(WithExposedHeaders(headers...))
expectedHeaders := []string{http.CanonicalHeaderKey("X-Custom-Header")}
if !slices.Equal(o.exposedHeaders, expectedHeaders) {
t.Errorf("got %v, want %v", o.exposedHeaders, expectedHeaders)
}
}
func TestWithAllowCredentials(t *testing.T) {
o := newOptions(WithAllowCredentials(true))
if !o.allowCredentials {
t.Errorf("got %v, want %v", o.allowCredentials, true)
}
}
func TestWithAllowPrivateNetworks(t *testing.T) {
o := newOptions(WithAllowPrivateNetworks(true))
if !o.allowPrivateNetworks {
t.Errorf("got %v, want %v", o.allowPrivateNetworks, true)
}
}
func TestWithMaxAge(t *testing.T) {
maxAge := 3600
o := newOptions(WithMaxAge(maxAge))
if o.maxAge != maxAge {
t.Errorf("got %v, want %v", o.maxAge, maxAge)
}
}

View File

@@ -0,0 +1,33 @@
package cors
import "strings"
// wildcard is a type that represents a wildcard string.
// This type allows faster matching of strings with a wildcard
// in comparison to using regex.
type wildcard struct {
prefix string
suffix string
}
// match returns true if the string s has the prefix and suffix of the wildcard.
func (w wildcard) match(s string) bool {
return len(s) >= len(w.prefix)+len(w.suffix) &&
strings.HasPrefix(s, w.prefix) &&
strings.HasSuffix(s, w.suffix)
}
func newWildcard(s string) wildcard {
if i := strings.IndexByte(s, '*'); i >= 0 {
return wildcard{
prefix: s[:i],
suffix: s[i+1:],
}
}
// fallback, usually this case should not happen
return wildcard{
prefix: s,
suffix: "",
}
}

View File

@@ -0,0 +1,94 @@
package cors
import "testing"
func TestWildcardMatch(t *testing.T) {
tests := []struct {
name string
wildcard wildcard
input string
expected bool
}{
{
name: "Match with prefix and suffix",
wildcard: newWildcard("http://*.example.com"),
input: "http://sub.example.com",
expected: true,
},
{
name: "No match with different prefix",
wildcard: newWildcard("http://*.example.com"),
input: "https://sub.example.com",
expected: false,
},
{
name: "No match with different suffix",
wildcard: newWildcard("http://*.example.com"),
input: "http://sub.example.org",
expected: false,
},
{
name: "Match with empty suffix",
wildcard: newWildcard("http://*"),
input: "http://example.com",
expected: true,
},
{
name: "Match with empty prefix",
wildcard: newWildcard("*.example.com"),
input: "sub.example.com",
expected: true,
},
{
name: "No match with empty prefix and different suffix",
wildcard: newWildcard("*.example.com"),
input: "sub.example.org",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.wildcard.match(tt.input); got != tt.expected {
t.Errorf("wildcard.match(%s) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}
func TestNewWildcard(t *testing.T) {
tests := []struct {
name string
input string
expected wildcard
}{
{
name: "Wildcard with prefix and suffix",
input: "http://*.example.com",
expected: wildcard{prefix: "http://", suffix: ".example.com"},
},
{
name: "Wildcard with empty suffix",
input: "http://*",
expected: wildcard{prefix: "http://", suffix: ""},
},
{
name: "Wildcard with empty prefix",
input: "*.example.com",
expected: wildcard{prefix: "", suffix: ".example.com"},
},
{
name: "No wildcard character",
input: "http://example.com",
expected: wildcard{prefix: "http://example.com", suffix: ""},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := newWildcard(tt.input); got != tt.expected {
t.Errorf("newWildcard(%s) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,137 @@
package csrf
import (
"context"
"net/http"
"slices"
)
// ContextValueIdentifier is the context value identifier for the CSRF token.
// The token is only stored in the context if the RefreshToken function was called before.
const ContextValueIdentifier = "_csrf_token"
// Middleware is a type that creates a new CSRF middleware. The CSRF middleware
// can be used to mitigate Cross-Site Request Forgery attacks.
type Middleware struct {
o options
}
// New returns a new CSRF middleware with the provided options.
func New(sessionReader SessionReader, sessionWriter SessionWriter, opts ...Option) *Middleware {
opts = append(opts, withSessionReader(sessionReader), withSessionWriter(sessionWriter))
o := newOptions(opts...)
m := &Middleware{
o: o,
}
checkForPRNG()
return m
}
// Handler returns the CSRF middleware handler. This middleware validates the CSRF token and calls the specified
// error handler if an invalid CSRF token was found.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if slices.Contains(m.o.ignoreMethods, r.Method) {
next.ServeHTTP(w, r) // skip CSRF check for ignored methods
return
}
// get the token from the request
token := m.o.tokenGetter(r)
storedToken := m.o.sessionGetter(r)
if !tokenEqual(token, storedToken) {
m.o.errCallback(w, r)
return
}
next.ServeHTTP(w, r) // execute the next handler
})
}
// RefreshToken generates a new CSRF Token and stores it in the session. The token is also passed to subsequent handlers
// via the context value ContextValueIdentifier.
func (m *Middleware) RefreshToken(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if GetToken(r.Context()) != "" {
// token already generated higher up in the chain
next.ServeHTTP(w, r)
return
}
// generate a new token
token := generateToken(m.o.tokenLength)
key := generateToken(m.o.tokenLength)
// mask the token
maskedToken := maskToken(token, key)
// store the encoded token in the session
encodedToken := encodeToken(maskedToken)
m.o.sessionWriter(r, encodedToken)
// pass the token down the chain via the context
r = r.WithContext(setToken(r.Context(), encodedToken))
next.ServeHTTP(w, r)
})
}
// region token-access
// GetToken retrieves the CSRF token from the given context. Ensure that the RefreshToken function was called before,
// otherwise, no token is populated in the context.
func GetToken(ctx context.Context) string {
token, ok := ctx.Value(ContextValueIdentifier).(string)
if !ok {
return ""
}
return token
}
// endregion token-access
// region internal-helpers
func setToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, ContextValueIdentifier, token)
}
// defaultTokenGetter is the default token getter function for the CSRF middleware.
// It checks the request form values, URL query parameters, and headers for the CSRF token.
// The order of precedence is:
// 1. Header "X-CSRF-TOKEN"
// 2. Header "X-XSRF-TOKEN"
// 3. URL query parameter "_csrf"
// 4. Form value "_csrf"
func defaultTokenGetter(r *http.Request) string {
if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 {
return t
}
if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 {
return t
}
if t := r.URL.Query().Get("_csrf"); len(t) > 0 {
return t
}
if t := r.FormValue("_csrf"); len(t) > 0 {
return t
}
return ""
}
// defaultErrorHandler is the default error handler function for the CSRF middleware.
// It writes a 403 Forbidden response.
func defaultErrorHandler(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
}
// endregion internal-helpers

View File

@@ -0,0 +1,251 @@
package csrf
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/h44z/wg-portal/internal/app/api/core/request"
)
func TestMiddleware_Handler(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
method string
token string
wantStatus int
}{
{"ValidToken", "POST", "stored-token", http.StatusOK},
{"ValidToken2", "PUT", "stored-token", http.StatusOK},
{"ValidToken3", "GET", "stored-token", http.StatusOK},
{"InvalidToken", "POST", "invalid-token", http.StatusForbidden},
{"IgnoredMethod", "GET", "", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/", nil)
req.Header.Set("X-CSRF-TOKEN", tt.token)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != tt.wantStatus {
t.Errorf("Handler() status = %d, want %d", status, tt.wantStatus)
}
})
}
}
func TestMiddleware_RefreshToken(t *testing.T) {
sessionToken := ""
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := GetToken(r.Context())
if token == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
}
if sessionToken == "" {
t.Errorf("RefreshToken() did not set token in session")
}
}
func TestMiddleware_RefreshToken_chained(t *testing.T) {
sessionToken := ""
tokenWrites := 0
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
tokenWrites++
}
m := New(sessionReader, sessionWriter)
handler := m.RefreshToken(m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := GetToken(r.Context())
if token == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusOK)
})))
req := httptest.NewRequest("POST", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
}
if sessionToken == "" {
t.Errorf("RefreshToken() did not set token in session")
}
if tokenWrites != 1 {
t.Errorf("RefreshToken() wrote token to session more than once: %d", tokenWrites)
}
}
func TestMiddleware_RefreshToken_Handler(t *testing.T) {
sessionToken := ""
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
// simulate two requests: first one GET request with the RefreshToken handler, the next one is a PUT request with
// the token from the first request added as X-CSRF-TOKEN header
// first request
retrievedToken := ""
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
retrievedToken = GetToken(r.Context())
if retrievedToken == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusAccepted)
}))
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusAccepted {
t.Errorf("Handler() status = %d, want %d", status, http.StatusAccepted)
}
if retrievedToken == "" {
t.Errorf("no token retrieved")
}
if retrievedToken != sessionToken {
t.Errorf("token in context does not match token in session")
}
// second request
req = httptest.NewRequest("PUT", "/", nil)
req.Header.Set("X-CSRF-TOKEN", retrievedToken)
rr = httptest.NewRecorder()
handler = m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}
func TestMiddleware_Handler_FormBody(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyData, err := request.BodyString(r)
if err != nil {
t.Errorf("Handler() error = %v, want nil", err)
}
// ensure that the body is empty - ParseForm() should have been called before by the CSRF middleware
if bodyData != "" {
t.Errorf("Handler() bodyData = %s, want empty", bodyData)
}
if r.FormValue("_csrf") != "stored-token" {
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/", nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Form = make(map[string][]string)
req.Form.Add("_csrf", "stored-token")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}
func TestMiddleware_Handler_FormBodyAvailable(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyData, err := request.BodyString(r)
if err != nil {
t.Errorf("Handler() error = %v, want nil", err)
}
// ensure that the body is not empty, as the CSRF middleware should not have read the body
if bodyData != "the original body" {
t.Errorf("Handler() bodyData = %s, want %s", bodyData, "the original body")
}
// check if the token is available in the form values (from query parameters)
if r.FormValue("_csrf") != "stored-token" {
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/?_csrf=stored-token", nil)
req.Header.Set("Content-Type", "text/plain")
req.Body = io.NopCloser(strings.NewReader("the original body"))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}

View File

@@ -0,0 +1,88 @@
package csrf
import "net/http"
type SessionReader func(r *http.Request) string
type SessionWriter func(r *http.Request, token string)
// options is a struct that contains options for the CSRF middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
tokenLength int
ignoreMethods []string
errCallbackOverride bool
errCallback func(w http.ResponseWriter, r *http.Request)
tokenGetterOverride bool
tokenGetter func(r *http.Request) string
sessionGetter SessionReader
sessionWriter SessionWriter
}
// Option is a type that is used to set options for the CSRF middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithTokenLength is a method that sets the token length for the CSRF middleware.
// The default value is 32.
func WithTokenLength(length int) Option {
return func(o *options) {
o.tokenLength = length
}
}
// WithErrorCallback is a method that sets the error callback function for the CSRF middleware.
// The error callback function is called when the CSRF token is invalid.
// The default behavior is to write a 403 Forbidden response.
func WithErrorCallback(fn func(w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.errCallback = fn
o.errCallbackOverride = true
}
}
// WithTokenGetter is a method that sets the token getter function for the CSRF middleware.
// The token getter function is called to get the CSRF token from the request.
// The default behavior is to get the token from the "X-CSRF-Token" header.
func WithTokenGetter(fn func(r *http.Request) string) Option {
return func(o *options) {
o.tokenGetter = fn
o.tokenGetterOverride = true
}
}
// withSessionReader is a method that sets the session reader function for the CSRF middleware.
// The session reader function is called to get the CSRF token from the session.
func withSessionReader(fn SessionReader) Option {
return func(o *options) {
o.sessionGetter = fn
}
}
// withSessionWriter is a method that sets the session writer function for the CSRF middleware.
// The session writer function is called to write the CSRF token to the session.
func withSessionWriter(fn SessionWriter) Option {
return func(o *options) {
o.sessionWriter = fn
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
tokenLength: 32,
ignoreMethods: []string{"GET", "HEAD", "OPTIONS"},
errCallbackOverride: false,
errCallback: defaultErrorHandler,
tokenGetterOverride: false,
tokenGetter: defaultTokenGetter,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@@ -0,0 +1,75 @@
package csrf
import (
"net/http"
"testing"
)
func TestWithTokenLength(t *testing.T) {
o := newOptions(WithTokenLength(64))
if o.tokenLength != 64 {
t.Errorf("WithTokenLength() = %d, want %d", o.tokenLength, 64)
}
}
func TestWithErrorCallback(t *testing.T) {
callback := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}
o := newOptions(WithErrorCallback(callback))
if !o.errCallbackOverride {
t.Errorf("WithErrorCallback() did not set errCallbackOverride to true")
}
if o.errCallback == nil {
t.Errorf("WithErrorCallback() did not set errCallback")
}
}
func TestWithTokenGetter(t *testing.T) {
getter := func(r *http.Request) string {
return "test-token"
}
o := newOptions(WithTokenGetter(getter))
if !o.tokenGetterOverride {
t.Errorf("WithTokenGetter() did not set tokenGetterOverride to true")
}
if o.tokenGetter == nil {
t.Errorf("WithTokenGetter() did not set tokenGetter")
}
}
func TestWithSessionReader(t *testing.T) {
reader := func(r *http.Request) string {
return "session-token"
}
o := newOptions(withSessionReader(reader))
if o.sessionGetter == nil {
t.Errorf("withSessionReader() did not set sessionGetter")
}
}
func TestWithSessionWriter(t *testing.T) {
writer := func(r *http.Request, token string) {
// do nothing
}
o := newOptions(withSessionWriter(writer))
if o.sessionWriter == nil {
t.Errorf("withSessionWriter() did not set sessionWriter")
}
}
func TestNewOptionsDefaults(t *testing.T) {
o := newOptions()
if o.tokenLength != 32 {
t.Errorf("newOptions() default tokenLength = %d, want %d", o.tokenLength, 32)
}
if len(o.ignoreMethods) != 3 {
t.Errorf("newOptions() default ignoreMethods length = %d, want %d", len(o.ignoreMethods), 3)
}
if o.errCallback == nil {
t.Errorf("newOptions() default errCallback is nil")
}
if o.tokenGetter == nil {
t.Errorf("newOptions() default tokenGetter is nil")
}
}

View File

@@ -0,0 +1,90 @@
package csrf
import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"slices"
)
// checkForPRNG is a function that checks if a cryptographically secure PRNG is available.
// If it is not available, the function panics.
func checkForPRNG() {
buf := make([]byte, 1)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err))
}
}
// generateToken is a function that generates a secure random CSRF token.
func generateToken(length int) []byte {
bytes := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
panic(err)
}
return bytes
}
// encodeToken is a function that encodes a token to a base64 string.
func encodeToken(token []byte) string {
return base64.URLEncoding.EncodeToString(token)
}
// decodeToken is a function that decodes a base64 string to a token.
func decodeToken(token string) ([]byte, error) {
return base64.URLEncoding.DecodeString(token)
}
// maskToken is a function that masks a token with a given key.
// The returned byte slice contains the key + the masked token.
// The key needs to have the same length as the token, otherwise the function panics.
// So the resulting slice has a length of len(token) * 2.
func maskToken(token, key []byte) []byte {
if len(token) != len(key) {
panic("token and key must have the same length")
}
// masked contains the key in the first half and the XOR masked token in the second half
tokenLength := len(token)
masked := make([]byte, tokenLength*2)
for i := 0; i < len(token); i++ {
masked[i] = key[i]
masked[i+tokenLength] = token[i] ^ key[i] // XOR mask
}
return masked
}
// unmaskToken is a function that unmask a token which contains the key in the first half.
// The returned byte slice contains the unmasked token, it has exactly half the length of the input slice.
func unmaskToken(masked []byte) []byte {
tokenLength := len(masked) / 2
token := make([]byte, tokenLength)
for i := 0; i < tokenLength; i++ {
token[i] = masked[i] ^ masked[i+tokenLength] // XOR unmask
}
return token
}
// tokenEqual is a function that compares two tokens for equality.
func tokenEqual(a, b string) bool {
decodedA, err := decodeToken(a)
if err != nil {
return false
}
decodedB, err := decodeToken(b)
if err != nil {
return false
}
unmaskedA := unmaskToken(decodedA)
unmaskedB := unmaskToken(decodedB)
return slices.Equal(unmaskedA, unmaskedB)
}

View File

@@ -0,0 +1,81 @@
package csrf
import (
"encoding/base64"
"testing"
)
func TestCheckForPRNG(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("checkForPRNG() panicked: %v", r)
}
}()
checkForPRNG()
}
func TestGenerateToken(t *testing.T) {
length := 32
token := generateToken(length)
if len(token) != length {
t.Errorf("generateToken() returned token of length %d, expected %d", len(token), length)
}
}
func TestEncodeToken(t *testing.T) {
token := []byte("testtoken")
encoded := encodeToken(token)
expected := base64.URLEncoding.EncodeToString(token)
if encoded != expected {
t.Errorf("encodeToken() = %v, want %v", encoded, expected)
}
}
func TestDecodeToken(t *testing.T) {
token := "dGVzdHRva2Vu"
expected := []byte("testtoken")
decoded, err := decodeToken(token)
if err != nil {
t.Errorf("decodeToken() error = %v", err)
}
if string(decoded) != string(expected) {
t.Errorf("decodeToken() = %v, want %v", decoded, expected)
}
}
func TestMaskToken(t *testing.T) {
token := []byte("testtoken")
key := []byte("keykeykey")
masked := maskToken(token, key)
if len(masked) != len(token)*2 {
t.Errorf("maskToken() returned masked token of length %d, expected %d", len(masked), len(token)*2)
}
}
func TestUnmaskToken(t *testing.T) {
token := []byte("testtoken")
key := []byte("keykeykey")
masked := maskToken(token, key)
unmasked := unmaskToken(masked)
if string(unmasked) != string(token) {
t.Errorf("unmaskToken() = %v, want %v", unmasked, token)
}
}
func TestTokenEqual(t *testing.T) {
tokenA := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}))
tokenB := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
if !tokenEqual(tokenA, tokenB) {
t.Errorf("tokenEqual() = false, want true")
}
tokenC := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x07, 0x08, 0x09}))
if !tokenEqual(tokenA, tokenC) {
t.Errorf("tokenEqual() = false, want true")
}
tokenD := encodeToken(maskToken([]byte{0x09, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
if tokenEqual(tokenA, tokenD) {
t.Errorf("tokenEqual() = true, want false")
}
}

View File

@@ -0,0 +1,199 @@
package logging
import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"
)
// LogLevel is an enumeration of the different log levels.
type LogLevel int
const (
LogLevelDebug LogLevel = iota
LogLevelInfo
LogLevelWarn
LogLevelError
)
// Logger is an interface that defines the methods that a logger must implement.
// This allows the logging middleware to be used with different logging libraries.
type Logger interface {
// Debugf logs a message at debug level.
Debugf(format string, args ...any)
// Infof logs a message at info level.
Infof(format string, args ...any)
// Warnf logs a message at warn level.
Warnf(format string, args ...any)
// Errorf logs a message at error level.
Errorf(format string, args ...any)
}
// Middleware is a type that creates a new logging middleware. The logging middleware
// logs information about each request.
type Middleware struct {
o options
}
// New returns a new logging middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
return m
}
// Handler returns the logging middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := newWriterWrapper(w)
start := time.Now()
defer func() {
info := m.extractInfoMap(r, start, ww)
if m.o.logger == nil {
msg, args := m.buildSlogMessageAndArguments(info)
m.logMsg(msg, args...)
} else {
msg := m.buildNormalLogMessage(info)
m.logMsg(msg)
}
}()
next.ServeHTTP(ww, r)
})
}
func (m *Middleware) extractInfoMap(r *http.Request, start time.Time, ww *writerWrapper) map[string]any {
info := make(map[string]any)
info["method"] = r.Method
info["path"] = r.URL.Path
info["protocol"] = r.Proto
info["clientIP"] = r.Header.Get("X-Forwarded-For")
if info["clientIP"] == "" {
// If the X-Forwarded-For header is not set, use the remote address without the port number.
lastColonIndex := strings.LastIndex(r.RemoteAddr, ":")
switch lastColonIndex {
case -1:
info["clientIP"] = r.RemoteAddr
default:
info["clientIP"] = r.RemoteAddr[:lastColonIndex]
}
}
info["userAgent"] = r.UserAgent()
info["referer"] = r.Header.Get("Referer")
info["duration"] = time.Since(start).String()
info["status"] = ww.StatusCode
info["dataLength"] = ww.WrittenBytes
if m.o.headerRequestIdKey != "" {
info["headerRequestId"] = r.Header.Get(m.o.headerRequestIdKey)
}
if m.o.contextRequestIdKey != "" {
info["contextRequestId"], _ = r.Context().Value(m.o.contextRequestIdKey).(string)
}
return info
}
func (m *Middleware) buildNormalLogMessage(info map[string]any) string {
switch {
case info["headerRequestId"] != nil && info["contextRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s ctx=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["headerRequestId"], info["contextRequestId"])
case info["headerRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["headerRequestId"])
case info["contextRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - ctx=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["contextRequestId"])
default:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"])
}
}
func (m *Middleware) buildSlogMessageAndArguments(info map[string]any) (message string, args []any) {
message = fmt.Sprintf("%s %s", info["method"], info["path"])
// Use a fixed order for the keys, so that the message is always the same.
// Skip method and path as they are already in the message.
keys := []string{
"protocol",
"status",
"dataLength",
"duration",
"clientIP",
"userAgent",
"referer",
"headerRequestId",
"contextRequestId",
}
for _, k := range keys {
if v, ok := info[k]; ok {
args = append(args, k, v) // only add key, value if it exists
}
}
return
}
func (m *Middleware) addPrefix(message string) string {
if m.o.prefix != "" {
return m.o.prefix + " " + message
}
return message
}
func (m *Middleware) logMsg(message string, args ...any) {
message = m.addPrefix(message)
if m.o.logger != nil {
switch m.o.logLevel {
case LogLevelDebug:
m.o.logger.Debugf(message, args...)
case LogLevelInfo:
m.o.logger.Infof(message, args...)
case LogLevelWarn:
m.o.logger.Warnf(message, args...)
case LogLevelError:
m.o.logger.Errorf(message, args...)
default:
m.o.logger.Infof(message, args...)
}
} else {
switch m.o.logLevel {
case LogLevelDebug:
slog.Debug(message, args...)
case LogLevelInfo:
slog.Info(message, args...)
case LogLevelWarn:
slog.Warn(message, args...)
case LogLevelError:
slog.Error(message, args...)
default:
slog.Info(message, args...)
}
}
}

View File

@@ -0,0 +1,148 @@
package logging
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type mockLogger struct {
messages []string
}
func (m *mockLogger) Debugf(format string, _ ...any) {
m.messages = append(m.messages, "DEBUG: "+format)
}
func (m *mockLogger) Infof(format string, _ ...any) {
m.messages = append(m.messages, "INFO: "+format)
}
func (m *mockLogger) Warnf(format string, _ ...any) {
m.messages = append(m.messages, "WARN: "+format)
}
func (m *mockLogger) Errorf(format string, _ ...any) {
m.messages = append(m.messages, "ERROR: "+format)
}
func TestMiddleware_Normal(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
}
expected := "Hello, World!"
if rr.Body.String() != expected {
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
}
if len(logger.messages) == 0 {
t.Errorf("expected log messages, got none")
}
if len(logger.messages) != 0 && !strings.Contains(logger.messages[0], "ERROR: GET /foo") {
t.Errorf("expected log message to contain request info, got %v", logger.messages[0])
}
}
func TestMiddleware_Extended(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithContextRequestIdKey("requestId"), WithHeaderRequestIdKey("X-Request-Id")).
Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
}
expected := "Hello, World!"
if rr.Body.String() != expected {
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
}
}
func TestMiddleware_Logger_remoteAddr(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "xhamster.com:1234"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrNoPort(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "xhamster.com"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrV6(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "[::1]:4711"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrV6NoPort(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "[::1]"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}

View File

@@ -0,0 +1,80 @@
package logging
// options is a struct that contains options for the logging middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
logLevel LogLevel
logger Logger
prefix string
contextRequestIdKey string
headerRequestIdKey string
}
// Option is a type that is used to set options for the logging middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithLevel is a method that sets the log level for the logging middleware.
// Possible values are LogLevelDebug, LogLevelInfo, LogLevelWarn, and LogLevelError.
// The default value is LogLevelInfo.
func WithLevel(level LogLevel) Option {
return func(o *options) {
o.logLevel = level
}
}
// WithPrefix is a method that sets the prefix for the logging middleware.
// If a prefix is set, it will be prepended to each log message. A space will
// be added between the prefix and the log message.
// The default value is an empty string.
func WithPrefix(prefix string) Option {
return func(o *options) {
o.prefix = prefix
}
}
// WithContextRequestIdKey is a method that sets the key for the request ID in the
// request context. If a key is set, the logging middleware will use this key to
// retrieve the request ID from the request context.
// The default value is an empty string, meaning the request ID will not be logged.
func WithContextRequestIdKey(key string) Option {
return func(o *options) {
o.contextRequestIdKey = key
}
}
// WithHeaderRequestIdKey is a method that sets the key for the request ID in the
// request headers. If a key is set, the logging middleware will use this key to
// retrieve the request ID from the request headers.
// The default value is an empty string, meaning the request ID will not be logged.
func WithHeaderRequestIdKey(key string) Option {
return func(o *options) {
o.headerRequestIdKey = key
}
}
// WithLogger is a method that sets the logger for the logging middleware.
// If a logger is set, the logging middleware will use this logger to log messages.
// The default logger is the structured slog logger.
func WithLogger(logger Logger) Option {
return func(o *options) {
o.logger = logger
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
logLevel: LogLevelInfo,
logger: nil,
prefix: "",
contextRequestIdKey: "",
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@@ -0,0 +1,88 @@
package logging
import (
"testing"
)
func TestWithLevel(t *testing.T) {
// table test to check all possible log levels
levels := []LogLevel{
LogLevelDebug,
LogLevelInfo,
LogLevelWarn,
LogLevelError,
}
for _, level := range levels {
opt := WithLevel(level)
o := newOptions(opt)
if o.logLevel != level {
t.Errorf("expected log level to be %v, got %v", level, o.logLevel)
}
}
}
func TestWithPrefix(t *testing.T) {
prefix := "TEST"
opt := WithPrefix(prefix)
o := newOptions(opt)
if o.prefix != prefix {
t.Errorf("expected prefix to be %v, got %v", prefix, o.prefix)
}
}
func TestWithContextRequestIdKey(t *testing.T) {
key := "contextKey"
opt := WithContextRequestIdKey(key)
o := newOptions(opt)
if o.contextRequestIdKey != key {
t.Errorf("expected contextRequestIdKey to be %v, got %v", key, o.contextRequestIdKey)
}
}
func TestWithHeaderRequestIdKey(t *testing.T) {
key := "headerKey"
opt := WithHeaderRequestIdKey(key)
o := newOptions(opt)
if o.headerRequestIdKey != key {
t.Errorf("expected headerRequestIdKey to be %v, got %v", key, o.headerRequestIdKey)
}
}
func TestWithLogger(t *testing.T) {
logger := &mockLogger{}
opt := WithLogger(logger)
o := newOptions(opt)
if o.logger != logger {
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
}
}
func TestDefaults(t *testing.T) {
o := newOptions()
if o.logLevel != LogLevelInfo {
t.Errorf("expected log level to be %v, got %v", LogLevelInfo, o.logLevel)
}
if o.logger != nil {
t.Errorf("expected logger to be nil, got %v", o.logger)
}
if o.prefix != "" {
t.Errorf("expected prefix to be empty, got %v", o.prefix)
}
if o.contextRequestIdKey != "" {
t.Errorf("expected contextRequestIdKey to be empty, got %v", o.contextRequestIdKey)
}
if o.headerRequestIdKey != "" {
t.Errorf("expected headerRequestIdKey to be empty, got %v", o.headerRequestIdKey)
}
}

View File

@@ -0,0 +1,45 @@
package logging
import (
"net/http"
)
// writerWrapper wraps a http.ResponseWriter and tracks the number of bytes written to it.
// It also tracks the http response code passed to the WriteHeader func of
// the ResponseWriter.
type writerWrapper struct {
http.ResponseWriter
// StatusCode is the last http response code passed to the WriteHeader func of
// the ResponseWriter. If no such call is made, a default code of http.StatusOK
// is assumed instead.
StatusCode int
// WrittenBytes is the number of bytes successfully written by the Write or
// ReadFrom function of the ResponseWriter. ResponseWriters may also write
// data to their underlaying connection directly (e.g. headers), but those
// are not tracked. Therefor the number of Written bytes will usually match
// the size of the response body.
WrittenBytes int64
}
// WriteHeader wraps the WriteHeader method of the ResponseWriter and tracks the
// http response code passed to it.
func (w *writerWrapper) WriteHeader(code int) {
w.StatusCode = code
w.ResponseWriter.WriteHeader(code)
}
// Write wraps the Write method of the ResponseWriter and tracks the number of bytes
// written to it.
func (w *writerWrapper) Write(data []byte) (int, error) {
n, err := w.ResponseWriter.Write(data)
w.WrittenBytes += int64(n)
return n, err
}
// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter.
// It initializes the StatusCode to http.StatusOK.
func newWriterWrapper(w http.ResponseWriter) *writerWrapper {
return &writerWrapper{ResponseWriter: w, StatusCode: http.StatusOK}
}

View File

@@ -0,0 +1,85 @@
package logging
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestWriterWrapper_WriteHeader(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
ww.WriteHeader(http.StatusNotFound)
if ww.StatusCode != http.StatusNotFound {
t.Errorf("expected status code to be %v, got %v", http.StatusNotFound, ww.StatusCode)
}
if rr.Code != http.StatusNotFound {
t.Errorf("expected recorder status code to be %v, got %v", http.StatusNotFound, rr.Code)
}
}
func TestWriterWrapper_Write(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
data := []byte("Hello, World!")
n, err := ww.Write(data)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if n != len(data) {
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
}
if ww.WrittenBytes != int64(len(data)) {
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
}
if rr.Body.String() != string(data) {
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
}
}
func TestWriterWrapper_WriteWithHeaders(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
data := []byte("Hello, World!")
n, err := ww.Write(data)
ww.Header().Set("Content-Type", "text/plain")
ww.Header().Set("X-Some-Header", "some-value")
ww.WriteHeader(http.StatusTeapot)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if n != len(data) {
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
}
if ww.WrittenBytes != int64(len(data)) {
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
}
if rr.Body.String() != string(data) {
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
}
if ww.StatusCode != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, ww.StatusCode)
}
}
func TestNewWriterWrapper(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
if ww.StatusCode != http.StatusOK {
t.Errorf("expected initial status code to be %v, got %v", http.StatusOK, ww.StatusCode)
}
if ww.WrittenBytes != 0 {
t.Errorf("expected initial WrittenBytes to be %v, got %v", 0, ww.WrittenBytes)
}
if ww.ResponseWriter != rr {
t.Errorf("expected ResponseWriter to be %v, got %v", rr, ww.ResponseWriter)
}
}

View File

@@ -0,0 +1,133 @@
package recovery
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"runtime/debug"
"strings"
)
// Logger is an interface that defines the methods that a logger must implement.
// This allows the logging middleware to be used with different logging libraries.
type Logger interface {
// Errorf logs a message at error level.
Errorf(format string, args ...any)
}
// Middleware is a type that creates a new recovery middleware. The recovery middleware
// recovers from panics and returns an Internal Server Error response. This middleware should
// be the first middleware in the middleware chain, so that it can recover from panics in other
// middlewares.
type Middleware struct {
o options
}
// New returns a new recovery middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
return m
}
// Handler returns the recovery middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
realErr, ok := err.(error)
if !ok {
realErr = fmt.Errorf("%v", err)
}
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
brokenPipe := isBrokenPipeError(realErr)
// Log the error and stack trace
if m.o.logCallback != nil {
m.o.logCallback(realErr, stack, brokenPipe)
}
switch {
case brokenPipe && m.o.brokenPipeCallback != nil:
m.o.brokenPipeCallback(realErr, stack, w, r)
case !brokenPipe && m.o.errCallback != nil:
m.o.errCallback(realErr, stack, w, r)
default:
// no callback set, simply recover and do nothing...
}
}
}()
next.ServeHTTP(w, r)
})
}
func addPrefix(o options, message string) string {
if o.defaultLogPrefix != "" {
return o.defaultLogPrefix + " " + message
}
return message
}
// defaultErrCallback is the default error callback function for the recovery middleware.
// It writes a JSON response with an Internal Server Error status code. If the exposeStackTrace option is
// enabled, the stack trace is included in the response.
func getDefaultErrCallback(o options) func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
return func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
responseBody := map[string]interface{}{
"error": "Internal Server Error",
}
if o.exposeStackTrace && len(stack) > 0 {
responseBody["stack"] = string(stack)
}
jsonBody, _ := json.Marshal(responseBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write(jsonBody)
}
}
// getDefaultLogCallback is the default log callback function for the recovery middleware.
// It logs the error and stack trace using the structured slog logger or the provided logger in Error level.
func getDefaultLogCallback(o options) func(error, []byte, bool) {
return func(err error, stack []byte, brokenPipe bool) {
if brokenPipe {
return // by default, ignore broken pipe errors
}
switch {
case o.useSlog:
slog.Error(addPrefix(o, err.Error()), "stack", string(stack))
case o.logger != nil:
o.logger.Errorf(fmt.Sprintf("%s; stacktrace=%s", addPrefix(o, err.Error()), string(stack)))
default:
// no logger set, do nothing...
}
}
}
func isBrokenPipeError(err error) bool {
var syscallErr *os.SyscallError
if errors.As(err, &syscallErr) {
errMsg := strings.ToLower(syscallErr.Err.Error())
if strings.Contains(errMsg, "broken pipe") ||
strings.Contains(errMsg, "connection reset by peer") {
return true
}
}
return false
}

View File

@@ -0,0 +1,149 @@
package recovery
import (
"errors"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
type mockLogger struct{}
func (m *mockLogger) Errorf(_ string, _ ...any) {}
func TestMiddleware(t *testing.T) {
tests := []struct {
name string
options []Option
panicSimulator func()
expectedStatus int
expectedBody string
expectStack bool
}{
{
name: "default behavior",
options: []Option{},
panicSimulator: func() {
panic(errors.New("test panic"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: `{"error":"Internal Server Error"}`,
},
{
name: "custom error callback",
options: []Option{
WithErrCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
w.Write([]byte("custom error"))
}),
},
panicSimulator: func() {
panic(errors.New("test panic"))
},
expectedStatus: http.StatusTeapot,
expectedBody: "custom error",
},
{
name: "broken pipe error",
options: []Option{
WithBrokenPipeCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte("broken pipe"))
}),
},
panicSimulator: func() {
panic(&os.SyscallError{Err: errors.New("broken pipe")})
},
expectedStatus: http.StatusServiceUnavailable,
expectedBody: "broken pipe",
},
{
name: "default callback broken pipe error",
options: nil,
panicSimulator: func() {
panic(&os.SyscallError{Err: errors.New("broken pipe")})
},
expectedStatus: http.StatusOK,
expectedBody: "",
},
{
name: "default callback normal error",
options: nil,
panicSimulator: func() {
panic("something went wrong")
},
expectedStatus: http.StatusInternalServerError,
expectedBody: "{\"error\":\"Internal Server Error\"}",
},
{
name: "default callback with stack trace",
options: []Option{
WithExposeStackTrace(true),
},
panicSimulator: func() {
panic("something went wrong")
},
expectedStatus: http.StatusInternalServerError,
expectedBody: "\"stack\":",
expectStack: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := New(tt.options...).Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tt.panicSimulator()
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("expected status %v, got %v", tt.expectedStatus, rr.Code)
}
if !tt.expectStack && rr.Body.String() != tt.expectedBody {
t.Errorf("expected body %v, got %v", tt.expectedBody, rr.Body.String())
}
if tt.expectStack && !strings.Contains(rr.Body.String(), tt.expectedBody) {
t.Errorf("expected body to contain %v, got %v", tt.expectedBody, rr.Body.String())
}
})
}
}
func TestIsBrokenPipeError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "broken pipe error",
err: &os.SyscallError{Err: errors.New("broken pipe")},
expected: true,
},
{
name: "connection reset by peer error",
err: &os.SyscallError{Err: errors.New("connection reset by peer")},
expected: true,
},
{
name: "other error",
err: errors.New("other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isBrokenPipeError(tt.err)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@@ -0,0 +1,129 @@
package recovery
import "net/http"
// options is a struct that contains options for the recovery middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
logger Logger
useSlog bool
errCallbackOverride bool
errCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
brokenPipeCallbackOverride bool
brokenPipeCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
exposeStackTrace bool
defaultLogPrefix string
logCallbackOverride bool
logCallback func(err error, stack []byte, brokenPipe bool)
}
// Option is a type that is used to set options for the recovery middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithErrCallback sets the error callback function for the recovery middleware.
// The error callback function is called when a panic is recovered by the middleware.
// This function completely overrides the default behavior of the middleware. It is the
// responsibility of the user to handle the error and write a response to the client.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithErrCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.errCallback = fn
o.errCallbackOverride = true
}
}
// WithBrokenPipeCallback sets the broken pipe callback function for the recovery middleware.
// The broken pipe callback function is called when a broken pipe error is recovered by the middleware.
// This function completely overrides the default behavior of the middleware. It is the responsibility
// of the user to handle the error and write a response to the client.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithBrokenPipeCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.brokenPipeCallback = fn
o.brokenPipeCallbackOverride = true
}
}
// WithLogCallback sets the log callback function for the recovery middleware.
// The log callback function is called when a panic is recovered by the middleware.
// This function allows the user to log the error and stack trace. The default behavior is to log
// the error and stack trace in Error level.
// This function completely overrides the default behavior of the middleware.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithLogCallback(fn func(err error, stack []byte, brokenPipe bool)) Option {
return func(o *options) {
o.logCallback = fn
o.logCallbackOverride = true
}
}
// WithLogger is a method that sets the logger for the logging middleware.
// If a logger is set, the logging middleware will use this logger to log messages.
// The default logger is the structured slog logger, see WithSlog.
func WithLogger(logger Logger) Option {
return func(o *options) {
o.logger = logger
}
}
// WithSlog is a method that sets whether the recovery middleware should use the structured slog logger.
// If set to true, the middleware will use the structured slog logger. If set to false, the middleware
// will not use any logger unless one is explicitly set with the WithLogger option.
// The default value is true.
func WithSlog(useSlog bool) Option {
return func(o *options) {
o.useSlog = useSlog
}
}
// WithDefaultLogPrefix is a method that sets the default log prefix for the recovery middleware.
// If a default log prefix is set and the default log callback is used, the prefix will be prepended
// to each log message. A space will be added between the prefix and the log message.
// The default value is an empty string.
func WithDefaultLogPrefix(defaultLogPrefix string) Option {
return func(o *options) {
o.defaultLogPrefix = defaultLogPrefix
}
}
// WithExposeStackTrace is a method that sets whether the stack trace should be exposed in the response.
// If set to true, the stack trace will be included in the response body. If set to false, the stack trace
// will not be included in the response body. This only applies to the default error callback.
// The default value is false.
func WithExposeStackTrace(exposeStackTrace bool) Option {
return func(o *options) {
o.exposeStackTrace = exposeStackTrace
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
logger: nil,
useSlog: true,
errCallback: nil,
brokenPipeCallback: nil, // by default, ignore broken pipe errors
exposeStackTrace: false,
defaultLogPrefix: "",
logCallback: nil,
}
for _, opt := range opts {
opt(&o)
}
if o.errCallback == nil && !o.errCallbackOverride {
o.errCallback = getDefaultErrCallback(o)
}
if o.logCallback == nil && !o.logCallbackOverride {
o.logCallback = getDefaultLogCallback(o)
}
return o
}

View File

@@ -0,0 +1,100 @@
package recovery
import (
"net/http"
"testing"
)
func TestWithErrCallback(t *testing.T) {
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
opt := WithErrCallback(callback)
o := newOptions(opt)
if o.errCallback == nil {
t.Errorf("expected errCallback to be set, got nil")
}
}
func TestWithBrokenPipeCallback(t *testing.T) {
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
opt := WithBrokenPipeCallback(callback)
o := newOptions(opt)
if o.brokenPipeCallback == nil {
t.Errorf("expected brokenPipeCallback to be set, got nil")
}
}
func TestWithLogCallback(t *testing.T) {
callback := func(err error, stack []byte, brokenPipe bool) {}
opt := WithLogCallback(callback)
o := newOptions(opt)
if o.logCallback == nil {
t.Errorf("expected logCallback to be set, got nil")
}
}
func TestWithLogger(t *testing.T) {
logger := &mockLogger{}
opt := WithLogger(logger)
o := newOptions(opt)
if o.logger != logger {
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
}
}
func TestWithSlog(t *testing.T) {
opt := WithSlog(false)
o := newOptions(opt)
if o.useSlog != false {
t.Errorf("expected useSlog to be false, got %v", o.useSlog)
}
}
func TestWithDefaultLogPrefix(t *testing.T) {
prefix := "PREFIX"
opt := WithDefaultLogPrefix(prefix)
o := newOptions(opt)
if o.defaultLogPrefix != prefix {
t.Errorf("expected defaultLogPrefix to be %v, got %v", prefix, o.defaultLogPrefix)
}
}
func TestWithExposeStackTrace(t *testing.T) {
opt := WithExposeStackTrace(true)
o := newOptions(opt)
if o.exposeStackTrace != true {
t.Errorf("expected exposeStackTrace to be true, got %v", o.exposeStackTrace)
}
}
func TestNewOptionsDefaults(t *testing.T) {
o := newOptions()
if o.logger != nil {
t.Errorf("expected logger to be nil, got %v", o.logger)
}
if o.useSlog != true {
t.Errorf("expected useSlog to be true, got %v", o.useSlog)
}
if o.errCallback == nil {
t.Errorf("expected errCallback to be set, got nil")
}
if o.brokenPipeCallback != nil {
t.Errorf("expected brokenPipeCallback to be nil, got %T", o.brokenPipeCallback)
}
if o.exposeStackTrace != false {
t.Errorf("expected exposeStackTrace to be false, got %T", o.exposeStackTrace)
}
if o.defaultLogPrefix != "" {
t.Errorf("expected defaultLogPrefix to be empty, got %T", o.defaultLogPrefix)
}
if o.logCallback == nil {
t.Errorf("expected logCallback to be set, got nil")
}
}

View File

@@ -0,0 +1,69 @@
package tracing
import (
"context"
"math/rand"
"net/http"
)
// Middleware is a type that creates a new tracing middleware. The tracing middleware
// can be used to trace requests based on a request ID header or parameter.
type Middleware struct {
o options
seededRand *rand.Rand
}
// New returns a new CORS middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
seededRand: rand.New(rand.NewSource(o.generateSeed)),
}
return m
}
// Handler returns the tracing middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqId string
// read upstream header und re-use it
if m.o.upstreamReqIdHeader != "" {
reqId = r.Header.Get(m.o.upstreamReqIdHeader)
}
// generate new id
if reqId == "" && m.o.generateLength > 0 {
reqId = m.generateRandomId()
}
// set response header
if m.o.headerIdentifier != "" {
w.Header().Set(m.o.headerIdentifier, reqId)
}
// set context value
if m.o.contextIdentifier != "" {
ctx := context.WithValue(r.Context(), m.o.contextIdentifier, reqId)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r) // execute the next handler
})
}
// region internal-helpers
func (m *Middleware) generateRandomId() string {
b := make([]byte, m.o.generateLength)
for i := range b {
b[i] = m.o.generateCharset[m.seededRand.Intn(len(m.o.generateCharset))]
}
return string(b)
}
// endregion internal-helpers

View File

@@ -0,0 +1,118 @@
package tracing
import (
"net/http"
"net/http/httptest"
"testing"
)
const defaultLength = 8
const upstreamHeaderValue = "upstream-id"
func TestMiddleware_Handler_WithUpstreamHeader(t *testing.T) {
m := New(WithUpstreamHeader("X-Upstream-Id"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Header.Get("X-Upstream-Id")
if reqId != upstreamHeaderValue {
t.Errorf("expected upstream request id to be 'upstream-id', got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Upstream-Id", upstreamHeaderValue)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-Id") != upstreamHeaderValue {
t.Errorf("expected X-Request-Id header to be set in the response")
}
}
func TestMiddleware_Handler_GenerateNewId(t *testing.T) {
idLen := 18
m := New(WithIdLength(idLen))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if len(reqId) != 18 {
t.Errorf("expected generated request id length to be %d, got %d", idLen, len(reqId))
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-Id") == "" || len(rr.Header().Get("X-Request-Id")) != idLen {
t.Errorf("expected X-Request-Id header to be set in the response")
}
}
func TestMiddleware_Handler_SetContextValue(t *testing.T) {
m := New()
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("RequestId").(string)
if reqId == "" || len(reqId) != defaultLength {
t.Errorf("expected context request id to be set, got empty string")
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_SetCustomContextValue(t *testing.T) {
m := New(WithContextIdentifier("Custom-Id"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("Custom-Id").(string)
if reqId == "" || len(reqId) != defaultLength {
t.Errorf("expected context request id to be set, got empty string")
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdGenerated(t *testing.T) {
m := New(WithIdLength(0))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if reqId != "" {
t.Errorf("expected no request id to be generated, got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdHeaderSet(t *testing.T) {
m := New(WithHeaderIdentifier(""))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if reqId != "" {
t.Errorf("expected no request id to be generated, got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdContextSet(t *testing.T) {
m := New(WithHeaderIdentifier(""))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("Request-Id")
if reqId != nil {
t.Errorf("expected no context request id to be set, got %v", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}

View File

@@ -0,0 +1,85 @@
package tracing
import "time"
// options is a struct that contains options for the tracing middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
upstreamReqIdHeader string
headerIdentifier string
contextIdentifier string
generateLength int
generateCharset string
generateSeed int64
}
// Option is a type that is used to set options for the tracing middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithIdSeed sets the seed for the random request id.
// If no seed is provided, the current timestamp is used.
func WithIdSeed(seed int64) Option {
return func(o *options) {
o.generateSeed = seed
}
}
// WithIdCharset sets the charset that is used to generate a random request id.
// By default, upper-case letters and numbers are used.
func WithIdCharset(charset string) Option {
return func(o *options) {
o.generateCharset = charset
}
}
// WithIdLength specifies the length of generated random ids.
// By default, a length of 8 is used. If the length is 0, no request id will be generated.
func WithIdLength(len int) Option {
return func(o *options) {
o.generateLength = len
}
}
// WithHeaderIdentifier specifies the header name for the request id that is added to the response headers.
// If the identifier is empty, the request id will not be added to the response headers.
func WithHeaderIdentifier(identifier string) Option {
return func(o *options) {
o.headerIdentifier = identifier
}
}
// WithUpstreamHeader sets the upstream header name, that should be used to fetch the request id.
// If no upstream header is found, a random id will be generated if the id-length parameter is set to a value > 0.
func WithUpstreamHeader(header string) Option {
return func(o *options) {
o.upstreamReqIdHeader = header
}
}
// WithContextIdentifier specifies the value-key for the request id that is added to the request context.
// If the identifier is empty, the request id will not be added to the context.
// If the request id is added to the context, it can be retrieved with:
// `id := r.Context().Value(THE-IDENTIFIER).(string)`
func WithContextIdentifier(identifier string) Option {
return func(o *options) {
o.contextIdentifier = identifier
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
headerIdentifier: "X-Request-Id",
contextIdentifier: "RequestId",
generateSeed: time.Now().UnixNano(),
generateCharset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
generateLength: 8,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@@ -0,0 +1,75 @@
package tracing
import (
"testing"
)
func TestWithIdSeed(t *testing.T) {
o := newOptions(WithIdSeed(12345))
if o.generateSeed != 12345 {
t.Errorf("expected generateSeed to be 12345, got %d", o.generateSeed)
}
}
func TestWithIdCharset(t *testing.T) {
o := newOptions(WithIdCharset("abc123"))
if o.generateCharset != "abc123" {
t.Errorf("expected generateCharset to be 'abc123', got %s", o.generateCharset)
}
}
func TestWithIdLength(t *testing.T) {
o := newOptions(WithIdLength(16))
if o.generateLength != 16 {
t.Errorf("expected generateLength to be 16, got %d", o.generateLength)
}
}
func TestWithHeaderIdentifier(t *testing.T) {
o := newOptions(WithHeaderIdentifier("X-Custom-Id"))
if o.headerIdentifier != "X-Custom-Id" {
t.Errorf("expected headerIdentifier to be 'X-Custom-Id', got %s", o.headerIdentifier)
}
}
func TestWithUpstreamHeader(t *testing.T) {
o := newOptions(WithUpstreamHeader("X-Upstream-Id"))
if o.upstreamReqIdHeader != "X-Upstream-Id" {
t.Errorf("expected upstreamReqIdHeader to be 'X-Upstream-Id', got %s", o.upstreamReqIdHeader)
}
}
func TestWithContextIdentifier(t *testing.T) {
o := newOptions(WithContextIdentifier("Request-Id"))
if o.contextIdentifier != "Request-Id" {
t.Errorf("expected contextIdentifier to be 'Request-Id', got %s", o.contextIdentifier)
}
}
func TestDefaults(t *testing.T) {
o := newOptions()
if o.generateLength != 8 {
t.Errorf("expected generateLength to be 8, got %d", o.generateLength)
}
if o.generateCharset != "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" {
t.Errorf("expected generateCharset to be 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', got %s", o.generateCharset)
}
if o.generateSeed == 0 {
t.Errorf("expected generateSeed to be non-zero")
}
if o.headerIdentifier != "X-Request-Id" {
t.Errorf("expected headerIdentifier to be 'X-Request-Id', got %s", o.headerIdentifier)
}
if o.upstreamReqIdHeader != "" {
t.Errorf("expected upstreamReqIdHeader to be empty, got %s", o.upstreamReqIdHeader)
}
if o.contextIdentifier != "RequestId" {
t.Errorf("expected contextIdentifier to be 'RequestId', got %s", o.contextIdentifier)
}
}

View File

@@ -0,0 +1,259 @@
// Package request provides functions to extract parameters from the request.
package request
import (
"encoding/json"
"io"
"net"
"net/http"
"net/textproto"
"slices"
"strings"
)
const CheckPrivateProxy = "PRIVATE"
// PathRaw returns the value of the named path parameter.
func PathRaw(r *http.Request, name string) string {
return r.PathValue(name)
}
// Path returns the value of the named path parameter.
// The return value is trimmed of leading and trailing whitespace.
func Path(r *http.Request, name string) string {
return strings.TrimSpace(PathRaw(r, name))
}
// PathDefault returns the value of the named path parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func PathDefault(r *http.Request, name string, defaultValue string) string {
value := r.PathValue(name)
if value == "" {
return defaultValue
}
return Path(r, name)
}
// QueryRaw returns the value of the named query parameter.
func QueryRaw(r *http.Request, name string) string {
return r.URL.Query().Get(name)
}
// Query returns the value of the named query parameter.
// The return value is trimmed of leading and trailing whitespace.
func Query(r *http.Request, name string) string {
return strings.TrimSpace(QueryRaw(r, name))
}
// QueryDefault returns the value of the named query parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func QueryDefault(r *http.Request, name string, defaultValue string) string {
if !r.URL.Query().Has(name) {
return defaultValue
}
return Query(r, name)
}
// QuerySlice returns the value of the named query parameter.
// All slice values are trimmed of leading and trailing whitespace.
func QuerySlice(r *http.Request, name string) []string {
values, ok := r.URL.Query()[name]
if !ok {
return nil
}
result := make([]string, len(values))
for i, value := range values {
result[i] = strings.TrimSpace(value)
}
return result
}
// QuerySliceDefault returns the value of the named query parameter.
// If the parameter is empty, it returns the default value.
// All slice values are trimmed of leading and trailing whitespace.
func QuerySliceDefault(r *http.Request, name string, defaultValue []string) []string {
if !r.URL.Query().Has(name) {
return defaultValue
}
return QuerySlice(r, name)
}
// FragmentRaw returns the value of the named fragment parameter.
func FragmentRaw(r *http.Request) string {
return r.URL.Fragment
}
// Fragment returns the value of the named fragment parameter.
// The return value is trimmed of leading and trailing whitespace.
func Fragment(r *http.Request) string {
return strings.TrimSpace(FragmentRaw(r))
}
// FragmentDefault returns the value of the named fragment parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func FragmentDefault(r *http.Request, defaultValue string) string {
if r.URL.Fragment == "" {
return defaultValue
}
return Fragment(r)
}
// FormRaw returns the value of the named form parameter.
func FormRaw(r *http.Request, name string) string {
return r.FormValue(name)
}
// Form returns the value of the named form parameter.
// The return value is trimmed of leading and trailing whitespace.
func Form(r *http.Request, name string) string {
return strings.TrimSpace(FormRaw(r, name))
}
// DefaultForm returns the value of the named form parameter.
// If the parameter is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func DefaultForm(r *http.Request, name, defaultValue string) string {
err := r.ParseForm()
if err != nil {
return defaultValue
}
if !r.Form.Has(name) {
return defaultValue
}
return Form(r, name)
}
// HeaderRaw returns the value of the named header.
func HeaderRaw(r *http.Request, name string) string {
return r.Header.Get(name)
}
// Header returns the value of the named header.
// The return value is trimmed of leading and trailing whitespace.
func Header(r *http.Request, name string) string {
return strings.TrimSpace(HeaderRaw(r, name))
}
// HeaderDefault returns the value of the named header.
// If the header is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func HeaderDefault(r *http.Request, name, defaultValue string) string {
if _, ok := textproto.MIMEHeader(r.Header)[name]; !ok {
return defaultValue
}
return Header(r, name)
}
// Cookie returns the value of the named cookie.
// The return value is trimmed of leading and trailing whitespace.
func Cookie(r *http.Request, name string) string {
cookie, err := r.Cookie(name)
if err != nil {
return ""
}
return strings.TrimSpace(cookie.Value)
}
// CookieDefault returns the value of the named cookie.
// If the cookie is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func CookieDefault(r *http.Request, name, defaultValue string) string {
cookie, err := r.Cookie(name)
if err != nil {
return defaultValue
}
return strings.TrimSpace(cookie.Value)
}
// ClientIp returns the client IP address.
//
// As the request may come from a proxy, the function checks the
// X-Real-Ip and X-Forwarded-For headers to get the real client IP
// if the request IP matches one of the allowed proxy IPs.
// If the special proxy value CheckPrivateProxy ("PRIVATE") is passed, the function will
// also check the header if the request IP is a private IP address.
func ClientIp(r *http.Request, allowedProxyIp ...string) string {
ipStr, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
switch {
case err != nil && strings.Contains(err.Error(), "missing port in address"):
ipStr = strings.TrimSpace(r.RemoteAddr)
case err != nil:
ipStr = ""
}
IP := net.ParseIP(ipStr)
if IP == nil {
return ""
}
isProxiedRequest := false
if len(allowedProxyIp) > 0 {
if slices.Contains(allowedProxyIp, IP.String()) {
isProxiedRequest = true
}
if IP.IsPrivate() && slices.Contains(allowedProxyIp, CheckPrivateProxy) {
isProxiedRequest = true
}
}
if isProxiedRequest {
realClientIP := r.Header.Get("X-Real-Ip")
if realClientIP == "" {
realClientIP = r.Header.Get("X-Forwarded-For")
}
if realClientIP != "" {
realIpStr, _, err := net.SplitHostPort(strings.TrimSpace(realClientIP))
switch {
case err != nil && strings.Contains(err.Error(), "missing port in address"):
realIpStr = realClientIP
case err != nil:
realIpStr = ipStr
}
realIP := net.ParseIP(realIpStr)
if realIP == nil {
return IP.String()
}
return realIP.String()
}
}
return IP.String()
}
// BodyJson decodes the JSON value from the request body into the target.
// The target must be a pointer to a struct or slice.
// The function returns an error if the JSON value could not be decoded.
// The body reader is closed after reading.
func BodyJson(r *http.Request, target any) error {
defer func() {
_ = r.Body.Close()
}()
return json.NewDecoder(r.Body).Decode(target)
}
// BodyString returns the request body as a string.
// The content is read and returned as is, without any processing.
// The body is assumed to be UTF-8 encoded.
func BodyString(r *http.Request) (string, error) {
defer func() {
_ = r.Body.Close()
}()
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
return string(bodyBytes), nil
}

View File

@@ -0,0 +1,221 @@
package request
import (
"io"
"net/http"
"net/url"
"slices"
"strings"
"testing"
)
func TestPath(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/test/sample"}}
r.SetPathValue("first", "test")
if got := Path(r, "first"); got != "test" {
t.Errorf("Path() = %v, want %v", got, "test")
}
}
func TestDefaultPath(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/"}}
if got := PathDefault(r, "test", "default"); got != "default" {
t.Errorf("PathDefault() = %v, want %v", got, "default")
}
}
func TestDefaultPath_noDefault(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/"}}
r.SetPathValue("first", "test")
if got := PathDefault(r, "first", "test"); got != "test" {
t.Errorf("PathDefault() = %v, want %v", got, "test")
}
}
func TestQuery(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value"}}
if got := Query(r, "name"); got != "value" {
t.Errorf("Query() = %v, want %v", got, "value")
}
}
func TestDefaultQuery(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: ""}}
if got := QueryDefault(r, "name", "default"); got != "default" {
t.Errorf("QueryDefault() = %v, want %v", got, "default")
}
}
func TestQuerySlice(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value1 &name=value2"}}
expected := []string{"value1", "value2"}
if got := QuerySlice(r, "name"); !slices.Equal(got, expected) {
t.Errorf("QuerySlice() = %v, want %v", got, expected)
}
}
func TestQuerySlice_empty(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value1&name=value2"}}
if got := QuerySlice(r, "nix"); !slices.Equal(got, nil) {
t.Errorf("QuerySlice() = %v, want %v", got, nil)
}
}
func TestDefaultQuerySlice(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: ""}}
defaultValue := []string{"default1", "default2"}
if got := QuerySliceDefault(r, "name", defaultValue); !slices.Equal(got, defaultValue) {
t.Errorf("QuerySliceDefault() = %v, want %v", got, defaultValue)
}
}
func TestFragment(t *testing.T) {
r := &http.Request{URL: &url.URL{Fragment: "section"}}
if got := Fragment(r); got != "section" {
t.Errorf("Fragment() = %v, want %v", got, "section")
}
}
func TestDefaultFragment(t *testing.T) {
r := &http.Request{URL: &url.URL{Fragment: ""}}
if got := FragmentDefault(r, "default"); got != "default" {
t.Errorf("FragmentDefault() = %v, want %v", got, "default")
}
}
func TestForm(t *testing.T) {
r := &http.Request{Form: url.Values{"name": {"value"}}}
if got := Form(r, "name"); got != "value" {
t.Errorf("Form() = %v, want %v", got, "value")
}
}
func TestDefaultForm(t *testing.T) {
r := &http.Request{Form: url.Values{}}
if got := DefaultForm(r, "name", "default"); got != "default" {
t.Errorf("DefaultForm() = %v, want %v", got, "default")
}
}
func TestHeader(t *testing.T) {
r := &http.Request{Header: http.Header{"X-Test-Header": {"value"}}}
if got := Header(r, "X-Test-Header"); got != "value" {
t.Errorf("Header() = %v, want %v", got, "value")
}
}
func TestDefaultHeader(t *testing.T) {
r := &http.Request{Header: http.Header{}}
if got := HeaderDefault(r, "X-Test-Header", "default"); got != "default" {
t.Errorf("HeaderDefault() = %v, want %v", got, "default")
}
}
func TestCookie(t *testing.T) {
r := &http.Request{Header: http.Header{"Cookie": {"name=value"}}}
if got := Cookie(r, "name"); got != "value" {
t.Errorf("Cookie() = %v, want %v", got, "value")
}
}
func TestDefaultCookie(t *testing.T) {
r := &http.Request{Header: http.Header{}}
if got := CookieDefault(r, "name", "default"); got != "default" {
t.Errorf("CookieDefault() = %v, want %v", got, "default")
}
}
func TestClientIp(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345"}
if got := ClientIp(r); got != "192.168.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
}
}
func TestClientIp_invalid(t *testing.T) {
r := &http.Request{RemoteAddr: "was_isn_des"}
if got := ClientIp(r); got != "" {
t.Errorf("ClientIp() = %v, want %v", got, "")
}
}
func TestClientIp_ignore_header(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
if got := ClientIp(r); got != "192.168.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
}
}
func TestClientIp_header1(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header2(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header3(t *testing.T) {
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, "1.1.1.1"); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header4(t *testing.T) {
r := &http.Request{RemoteAddr: "8.8.8.8:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, "1.1.1.1"); got != "8.8.8.8" {
t.Errorf("ClientIp() = %v, want %v", got, "8.8.8.8")
}
}
func TestClientIp_header_invalid(t *testing.T) {
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"so-sicher-nit"}}}
if got := ClientIp(r, "1.1.1.1"); got != "1.1.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "1.1.1.1")
}
}
func TestBodyJson(t *testing.T) {
type TestStruct struct {
Name string `json:"name"`
Value int `json:"value"`
}
jsonStr := `{"name": "test", "value": 123}`
r := &http.Request{
Body: io.NopCloser(strings.NewReader(jsonStr)),
}
var result TestStruct
err := BodyJson(r, &result)
if err != nil {
t.Fatalf("BodyJson() error = %v", err)
}
expected := TestStruct{Name: "test", Value: 123}
if result != expected {
t.Errorf("BodyJson() = %v, want %v", result, expected)
}
}
func TestBodyString(t *testing.T) {
bodyStr := "test body content"
r := &http.Request{
Body: io.NopCloser(strings.NewReader(bodyStr)),
}
result, err := BodyString(r)
if err != nil {
t.Fatalf("BodyString() error = %v", err)
}
if result != bodyStr {
t.Errorf("BodyString() = %v, want %v", result, bodyStr)
}
}

View File

@@ -0,0 +1,100 @@
// Package respond provides a set of utility functions to help with the HTTP response handling.
package respond
import (
"encoding/json"
"io"
"net/http"
"strconv"
)
// Status writes a response with the given status code.
// The response will not contain any data.
func Status(w http.ResponseWriter, code int) {
w.WriteHeader(code)
}
// String writes a plain text response with the given status code and data.
// The Content-Type header is set to text/plain with a charset of utf-8.
func String(w http.ResponseWriter, code int, data string) {
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
w.WriteHeader(code)
_, _ = w.Write([]byte(data))
}
// JSON writes a JSON response with the given status code and data.
// If data is nil, the response will null. The status code is set to the given code.
// The Content-Type header is set to application/json.
// If the given data is not JSON serializable, the response will not contain any data.
// All encoding errors are silently ignored.
func JSON(w http.ResponseWriter, code int, data any) {
w.Header().Set("Content-Type", "application/json")
// if no data was given, simply return null
if data == nil {
w.WriteHeader(code)
_, _ = w.Write([]byte("null"))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(data)
}
// Data writes a response with the given status code, content type, and data.
// If no content type is provided, it is detected from the data.
func Data(w http.ResponseWriter, code int, contentType string, data []byte) {
if contentType == "" {
contentType = http.DetectContentType(data) // ensure content type is set
}
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(code)
_, _ = w.Write(data)
}
// Reader writes a response with the given status code, content type, and data.
// The content length is optional, it is only set if the given length is greater than 0.
func Reader(w http.ResponseWriter, code int, contentType string, contentLength int, data io.Reader) {
w.Header().Set("Content-Type", contentType)
if contentLength > 0 {
w.Header().Set("Content-Length", strconv.Itoa(contentLength))
}
w.WriteHeader(code)
_, _ = io.Copy(w, data)
}
// Attachment writes a response with the given status code, content type, filename, and data.
// If no content type is provided, it is detected from the data.
func Attachment(w http.ResponseWriter, code int, filename, contentType string, data []byte) {
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
Data(w, code, contentType, data)
}
// AttachmentReader writes a response with the given status code, content type, filename, content length, and data.
// The content length is optional, it is only set if the given length is greater than 0.
func AttachmentReader(
w http.ResponseWriter,
code int,
filename, contentType string,
contentLength int,
data io.Reader,
) {
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
Reader(w, code, contentType, contentLength, data)
}
// Redirect writes a response with the given status code and redirects to the given URL.
// The redirect url will always be an absolute URL. If the given URL is relative,
// the original request URL is used as the base.
func Redirect(w http.ResponseWriter, r *http.Request, code int, url string) {
http.Redirect(w, r, url, code)
}

View File

@@ -0,0 +1,273 @@
package respond
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
)
func TestStatus(t *testing.T) {
rec := httptest.NewRecorder()
Status(rec, http.StatusNoContent)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
t.Errorf("expected status %d, got %d", http.StatusNoContent, res.StatusCode)
}
body, _ := io.ReadAll(res.Body)
if len(body) != 0 {
t.Errorf("expected no body, got %s", body)
}
}
func TestString(t *testing.T) {
rec := httptest.NewRecorder()
String(rec, http.StatusOK, "Hello, World!")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain;charset=utf-8" {
t.Errorf("expected content type %s, got %s", "text/plain;charset=utf-8", contentType)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestJSON(t *testing.T) {
rec := httptest.NewRecorder()
data := map[string]string{"hello": "world"}
JSON(rec, http.StatusOK, data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
t.Errorf("expected content type %s, got %s", "application/json", contentType)
}
var body map[string]string
_ = json.NewDecoder(res.Body).Decode(&body)
if body["hello"] != "world" {
t.Errorf("expected body %v, got %v", data, body)
}
}
func TestJSON_empty(t *testing.T) {
rec := httptest.NewRecorder()
JSON(rec, http.StatusOK, nil)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
t.Errorf("expected content type %s, got %s", "application/json", contentType)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "null" {
t.Errorf("expected body %s, got %s", "null", body)
}
}
func TestData(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
Data(rec, http.StatusOK, "text/plain", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestData_noContentType(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte{0x1, 0x2, 0x3, 0x4, 0x5}
Data(rec, http.StatusOK, "", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/octet-stream" {
t.Errorf("expected content type %s, got %s", "application/octet-stream", contentType)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestReader(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
reader := bytes.NewBufferString(string(data))
Reader(rec, http.StatusOK, "text/plain", len(data), reader)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
if contentLength := res.Header.Get("Content-Length"); contentLength != strconv.Itoa(len(data)) {
t.Errorf("expected content length %d, got %s", len(data), contentLength)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestReader_unknownLength(t *testing.T) {
rec := httptest.NewRecorder()
data := bytes.NewBufferString("Hello, World!")
Reader(rec, http.StatusOK, "text/plain", 0, data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
if contentLength := res.Header.Get("Content-Length"); contentLength != "" {
t.Errorf("expected no content length, got %s", contentLength)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestAttachment(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
Attachment(rec, http.StatusOK, "example.txt", "text/plain", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestAttachmentReader(t *testing.T) {
rec := httptest.NewRecorder()
data := bytes.NewBufferString("Hello, World!")
AttachmentReader(rec, http.StatusOK, "example.txt", "text/plain", data.Len(), data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestRedirect(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/old", nil)
url := "http://example.com/new"
Redirect(rec, req, http.StatusMovedPermanently, url)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusMovedPermanently {
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
}
if location := res.Header.Get("Location"); location != url {
t.Errorf("expected location %s, got %s", url, location)
}
}
func TestRedirect_relative(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/old/dir", nil)
url := "newlocation/sub"
want := "/old/newlocation/sub"
Redirect(rec, req, http.StatusMovedPermanently, url)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusMovedPermanently {
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
}
if location := res.Header.Get("Location"); location != want {
t.Errorf("expected location %s, got %s", want, location)
}
}

View File

@@ -0,0 +1,46 @@
package respond
import (
"fmt"
"io"
"net/http"
)
// TplData is a map of template data. This is a convenience type for passing data to templates.
type TplData map[string]any
// TemplateInstance is an interface that wraps the ExecuteTemplate method.
// It is implemented by the html/template and text/template packages.
type TemplateInstance interface {
// ExecuteTemplate executes a template with the given name and data.
ExecuteTemplate(wr io.Writer, name string, data any) error
}
// TemplateRenderer is a renderer that uses a template instance to render HTML or Text templates.
type TemplateRenderer struct {
t TemplateInstance
}
// NewTemplateRenderer creates a new HTML or Text template renderer with the given template instance.
func NewTemplateRenderer(t TemplateInstance) *TemplateRenderer {
return &TemplateRenderer{t: t}
}
// Render renders a template with the given name and data.
// If rendering fails, it will panic with an error.
func (r *TemplateRenderer) Render(w http.ResponseWriter, code int, name, contentType string, data any) {
w.Header().Set("Content-Type", contentType)
w.WriteHeader(code)
err := r.t.ExecuteTemplate(w, name, data)
if err != nil {
panic(fmt.Errorf("error rendering template %s: %v", name, err))
}
}
// HTML renders a template with the given name and data. It is a convenience method for Render.
// The content type is set to "text/html" and the encoding to "utf-8".
// If rendering fails, it will panic with an error.
func (r *TemplateRenderer) HTML(w http.ResponseWriter, code int, name string, data any) {
r.Render(w, code, name, "text/html;charset=utf-8", data)
}

View File

@@ -0,0 +1,67 @@
package respond
import (
"html/template"
"io"
"net/http"
"net/http/httptest"
"testing"
)
type mockTemplate struct {
tmpl *template.Template
}
func (m *mockTemplate) ExecuteTemplate(wr io.Writer, name string, data any) error {
return m.tmpl.ExecuteTemplate(wr, name, data)
}
func TestTemplateRenderer_Render(t *testing.T) {
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}Hello, {{.}}!{{end}}`))
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
rec := httptest.NewRecorder()
renderer.Render(rec, http.StatusOK, "test", "text/plain", "World")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
body, _ := io.ReadAll(res.Body)
expectedBody := "Hello, World!"
if string(body) != expectedBody {
t.Errorf("expected body %s, got %s", expectedBody, string(body))
}
}
func TestTemplateRenderer_HTML(t *testing.T) {
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}<p>Hello, {{.}}!</p>{{end}}`))
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
rec := httptest.NewRecorder()
renderer.HTML(rec, http.StatusOK, "test", "World")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/html;charset=utf-8" {
t.Errorf("expected content type %s, got %s", "text/html;charset=utf-8", contentType)
}
body, _ := io.ReadAll(res.Body)
expectedBody := "<p>Hello, World!</p>"
if string(body) != expectedBody {
t.Errorf("expected body %s, got %s", expectedBody, string(body))
}
}

View File

@@ -2,27 +2,25 @@ package core
import (
"context"
"encoding/base64"
"fmt"
"html/template"
"io"
"io/fs"
"log/slog"
"math/rand"
"net/http"
"os"
"time"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/logging"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/recovery"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/tracing"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/config"
)
var (
random = rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
)
const (
RequestIDKey = "X-Request-ID"
)
@@ -30,19 +28,21 @@ const (
type ApiVersion string
type HandlerName string
type GroupSetupFn func(group *gin.RouterGroup)
type GroupSetupFn func(group *routegroup.Bundle)
type ApiEndpointSetupFunc func() (ApiVersion, GroupSetupFn)
type Server struct {
cfg *config.Config
server *gin.Engine
versions map[ApiVersion]*gin.RouterGroup
server *routegroup.Bundle
tpl *respond.TemplateRenderer
versions map[ApiVersion]*routegroup.Bundle
}
func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, error) {
s := &Server{
cfg: cfg,
cfg: cfg,
server: routegroup.New(http.NewServeMux()),
}
hostname, err := os.Hostname()
@@ -51,69 +51,39 @@ func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server,
}
hostname += ", version " + internal.Version
// Setup http server
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
s.server = gin.New()
s.server.Use(recovery.New().Handler)
if cfg.Web.RequestLogging {
if cfg.Advanced.LogLevel == "trace" {
gin.SetMode(gin.DebugMode)
}
s.server.Use(func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
s.server.Use(logging.New(logging.WithLevel(logging.LogLevelDebug)).Handler)
c.Next()
if raw != "" {
path = path + "?" + raw
}
latency := time.Since(start)
status := c.Writer.Status()
clientIP := c.ClientIP()
method := c.Request.Method
errorMsg := c.Errors.ByType(gin.ErrorTypePrivate).String()
slog.Debug("HTTP Request",
"status", status,
"latency", latency,
"client", clientIP,
"method", method,
"path", path,
"error", errorMsg,
)
}
s.server.Use(cors.New().Handler)
s.server.Use(tracing.New(
tracing.WithContextIdentifier(RequestIDKey),
tracing.WithHeaderIdentifier(RequestIDKey),
).Handler)
if cfg.Web.ExposeHostInfo {
s.server.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Served-By", hostname)
handler.ServeHTTP(w, r)
})
})
}
s.server.Use(gin.Recovery()).Use(func(c *gin.Context) {
c.Writer.Header().Set("X-Served-By", hostname)
c.Next()
}).Use(func(c *gin.Context) {
xRequestID := uuid(16)
c.Request.Header.Set(RequestIDKey, xRequestID)
c.Set(RequestIDKey, xRequestID)
c.Next()
})
// Setup templates
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(apiTemplates, "assets/tpl/*.gohtml"))
s.server.SetHTMLTemplate(templates)
s.tpl = respond.NewTemplateRenderer(
template.Must(template.New("").ParseFS(apiTemplates, "assets/tpl/*.gohtml")),
)
// Serve static files
imgFs := http.FS(fsMust(fs.Sub(apiStatics, "assets/img")))
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
s.server.StaticFS("/img", imgFs)
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
s.server.StaticFS("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
s.server.HandleFiles("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
s.server.HandleFiles("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
s.server.HandleFiles("/img", imgFs)
s.server.HandleFiles("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
s.server.HandleFiles("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
// Setup routes
s.server.UseRawPath = true
s.server.UnescapePathValues = true
s.setupRoutes(endpoints...)
s.setupFrontendRoutes()
@@ -136,9 +106,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
err = srv.ListenAndServe()
}
if err != nil {
slog.Info("web service exited",
"address", listenAddress,
"error", err)
slog.Info("web service exited", "address", listenAddress, "error", err)
cancelFn()
}
}()
@@ -157,18 +125,18 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
}
func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
s.server.GET("/api", s.landingPage)
s.versions = make(map[ApiVersion]*gin.RouterGroup)
s.server.HandleFunc("GET /api", s.landingPage)
s.versions = make(map[ApiVersion]*routegroup.Bundle)
for _, setupFunc := range endpoints {
version, groupSetupFn := setupFunc()
if _, ok := s.versions[version]; !ok {
s.versions[version] = s.server.Group(fmt.Sprintf("/api/%s", version))
s.versions[version] = s.server.Mount(fmt.Sprintf("/api/%s", version))
// OpenAPI documentation (via RapiDoc)
s.versions[version].GET("/swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
s.versions[version].GET("/doc.html", s.rapiDocHandler(version))
s.versions[version].HandleFunc("GET /swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
s.versions[version].HandleFunc("GET /doc.html", s.rapiDocHandler(version))
groupSetupFn(s.versions[version])
}
@@ -177,25 +145,27 @@ func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
func (s *Server) setupFrontendRoutes() {
// Serve static files
s.server.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/app")
s.server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
respond.Redirect(w, r, http.StatusMovedPermanently, "/app")
})
s.server.GET("/favicon.ico", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/app/favicon.ico")
s.server.HandleFunc("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
respond.Redirect(w, r, http.StatusMovedPermanently, "/app/favicon.ico")
})
s.server.StaticFS("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
s.server.HandleFiles("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
}
func (s *Server) landingPage(c *gin.Context) {
c.HTML(http.StatusOK, "index.gohtml", gin.H{
func (s *Server) landingPage(w http.ResponseWriter, _ *http.Request) {
s.tpl.HTML(w, http.StatusOK, "index.gohtml", respond.TplData{
"Version": internal.Version,
"Year": time.Now().Year(),
})
}
func (s *Server) rapiDocHandler(version ApiVersion) gin.HandlerFunc {
return func(c *gin.Context) {
c.HTML(http.StatusOK, "rapidoc.gohtml", gin.H{
func (s *Server) rapiDocHandler(version ApiVersion) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.tpl.HTML(w, http.StatusOK, "rapidoc.gohtml", respond.TplData{
"RapiDocSource": "/js/rapidoc-min.js",
"ApiSpecUrl": fmt.Sprintf("/doc/%s_swagger.yaml", version),
"Version": internal.Version,
@@ -210,9 +180,3 @@ func fsMust(f fs.FS, err error) fs.FS {
}
return f
}
func uuid(len int) string {
bytes := make([]byte, len)
random.Read(bytes)
return base64.StdEncoding.EncodeToString(bytes)[:len]
}