mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-14 06:51:15 +00:00
chore: replace gin with standard lib net/http
This commit is contained in:
214
internal/app/api/core/middleware/cors/middleware.go
Normal file
214
internal/app/api/core/middleware/cors/middleware.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Middleware is a type that creates a new CORS middleware. The CORS middleware
|
||||
// adds Cross-Origin Resource Sharing headers to the response. This middleware should
|
||||
// be used to allow cross-origin requests to your server.
|
||||
type Middleware struct {
|
||||
o options
|
||||
|
||||
varyHeaders string // precomputed Vary header
|
||||
allOrigins bool // all origins are allowed
|
||||
}
|
||||
|
||||
// New returns a new CORS middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
// set vary headers
|
||||
if m.o.allowPrivateNetworks {
|
||||
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"
|
||||
} else {
|
||||
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers"
|
||||
}
|
||||
|
||||
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
|
||||
m.allOrigins = true
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the CORS middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle preflight requests and stop the chain as some other
|
||||
// middleware may not handle OPTIONS requests correctly.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
m.handlePreflight(w, r)
|
||||
w.WriteHeader(http.StatusNoContent) // always return 204 No Content
|
||||
return
|
||||
}
|
||||
|
||||
// handle normal CORS requests
|
||||
m.handleNormal(w, r)
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
// handlePreflight handles preflight requests. If the request was successful, this function will
|
||||
// write the CORS headers and return. If the request was not successful, this function will
|
||||
// not add any CORS headers and return - thus the CORS request is considered invalid.
|
||||
func (m *Middleware) handlePreflight(w http.ResponseWriter, r *http.Request) {
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
w.Header().Add("Vary", m.varyHeaders)
|
||||
|
||||
// check origin
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return // not a valid CORS request
|
||||
}
|
||||
|
||||
if !m.originAllowed(origin) {
|
||||
return
|
||||
}
|
||||
|
||||
// check method
|
||||
reqMethod := r.Header.Get("Access-Control-Request-Method")
|
||||
if !m.methodAllowed(reqMethod) {
|
||||
return
|
||||
}
|
||||
|
||||
// check headers
|
||||
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||||
if !m.headersAllowed(reqHeaders) {
|
||||
return
|
||||
}
|
||||
|
||||
// set CORS headers for the successful preflight request
|
||||
if m.allOrigins {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", reqMethod)
|
||||
if reqHeaders != "" {
|
||||
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
|
||||
// from Access-Control-Request-Headers can be enough
|
||||
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
|
||||
}
|
||||
if m.o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if m.o.allowPrivateNetworks && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
|
||||
w.Header().Set("Access-Control-Allow-Private-Network", "true")
|
||||
}
|
||||
if m.o.maxAge > 0 {
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.o.maxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// handleNormal handles normal CORS requests. If the request was successful, this function will
|
||||
// write the CORS headers to the response. If the request was not successful, this function will
|
||||
// not add any CORS headers to the response. In this case, the CORS request is considered invalid.
|
||||
func (m *Middleware) handleNormal(w http.ResponseWriter, r *http.Request) {
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// check origin
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return // not a valid CORS request
|
||||
}
|
||||
|
||||
if !m.originAllowed(origin) {
|
||||
return
|
||||
}
|
||||
|
||||
// check method
|
||||
if !m.methodAllowed(r.Method) {
|
||||
return
|
||||
}
|
||||
|
||||
// set CORS headers for the successful CORS request
|
||||
if m.allOrigins {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
|
||||
}
|
||||
if len(m.o.exposedHeaders) > 0 {
|
||||
w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.o.exposedHeaders, ", "))
|
||||
}
|
||||
if m.o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) originAllowed(origin string) bool {
|
||||
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
// check simple origins
|
||||
if slices.Contains(m.o.allowedOrigins, origin) {
|
||||
return true
|
||||
}
|
||||
|
||||
// check wildcard origins
|
||||
for _, allowedOrigin := range m.o.allowedOriginPatterns {
|
||||
if allowedOrigin.match(origin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Middleware) methodAllowed(method string) bool {
|
||||
if method == http.MethodOptions {
|
||||
return true // preflight request is always allowed
|
||||
}
|
||||
|
||||
if len(m.o.allowedMethods) == 1 && m.o.allowedMethods[0] == "*" {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
if slices.Contains(m.o.allowedMethods, method) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Middleware) headersAllowed(headers string) bool {
|
||||
if headers == "" {
|
||||
return true // no headers are requested
|
||||
}
|
||||
|
||||
if len(m.o.allowedHeaders) == 0 {
|
||||
return false // no headers are allowed
|
||||
}
|
||||
|
||||
if _, ok := m.o.allowedHeaders["*"]; ok {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
// split headers by comma (according to definition, the headers are sorted and in lowercase)
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers
|
||||
for header := range strings.SplitSeq(headers, ",") {
|
||||
if _, ok := m.o.allowedHeaders[strings.TrimSpace(header)]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
101
internal/app/api/core/middleware/cors/middleware_test.go
Normal file
101
internal/app/api/core/middleware/cors/middleware_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMiddleware_New(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("*"))
|
||||
|
||||
if len(m.varyHeaders) == 0 {
|
||||
t.Errorf("expected vary headers to be populated, got %v", m.varyHeaders)
|
||||
}
|
||||
if !m.allOrigins {
|
||||
t.Errorf("expected allOrigins to be true, got %v", m.allOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_normal(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Result().StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status code 200, got %d", w.Result().StatusCode)
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
|
||||
w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_preflight(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "http://example.com", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Result().StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected status code 204, got %d", w.Result().StatusCode)
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
|
||||
w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_originAllowed(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
if !m.originAllowed("http://example.com") {
|
||||
t.Errorf("expected origin 'http://example.com' to be allowed")
|
||||
}
|
||||
|
||||
if m.originAllowed("http://notallowed.com") {
|
||||
t.Errorf("expected origin 'http://notallowed.com' to be not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_methodAllowed(t *testing.T) {
|
||||
m := New(WithAllowedMethods(http.MethodGet, http.MethodPost))
|
||||
|
||||
if !m.methodAllowed(http.MethodGet) {
|
||||
t.Errorf("expected method 'GET' to be allowed")
|
||||
}
|
||||
|
||||
if m.methodAllowed(http.MethodDelete) {
|
||||
t.Errorf("expected method 'DELETE' to be not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_headersAllowed(t *testing.T) {
|
||||
m := New(WithAllowedHeaders("Content-Type", "Authorization"))
|
||||
|
||||
if !m.headersAllowed("content-type, authorization") {
|
||||
t.Errorf("expected headers 'Content-Type, Authorization' to be allowed")
|
||||
}
|
||||
|
||||
if m.headersAllowed("x-custom-header") {
|
||||
t.Errorf("expected header 'X-Custom-Header' to be not allowed")
|
||||
}
|
||||
}
|
133
internal/app/api/core/middleware/cors/options.go
Normal file
133
internal/app/api/core/middleware/cors/options.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type void struct{}
|
||||
|
||||
// options is a struct that contains options for the CORS middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
allowedOrigins []string // origins without wildcards
|
||||
allowedOriginPatterns []wildcard // origins with wildcards
|
||||
allowedMethods []string
|
||||
allowedHeaders map[string]void
|
||||
exposedHeaders []string // these are in addition to the CORS-safelisted response headers
|
||||
allowCredentials bool
|
||||
allowPrivateNetworks bool
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the CORS middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithAllowedOrigins sets the allowed origins for the CORS middleware.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// An origin may contain a wildcard (*) to replace 0 or more characters
|
||||
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
|
||||
// Only one wildcard can be used per origin.
|
||||
// By default, all origins are allowed (*).
|
||||
func WithAllowedOrigins(origins ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedOrigins = nil
|
||||
o.allowedOriginPatterns = nil
|
||||
|
||||
for _, origin := range origins {
|
||||
if len(origin) > 1 && strings.Contains(origin, "*") {
|
||||
o.allowedOriginPatterns = append(
|
||||
o.allowedOriginPatterns,
|
||||
newWildcard(origin),
|
||||
)
|
||||
} else {
|
||||
o.allowedOrigins = append(o.allowedOrigins, origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedMethods sets the allowed methods for the CORS middleware.
|
||||
// By default, all methods are allowed (*).
|
||||
func WithAllowedMethods(methods ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedMethods = methods
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedHeaders sets the allowed headers for the CORS middleware.
|
||||
// By default, all headers are allowed (*).
|
||||
func WithAllowedHeaders(headers ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedHeaders = make(map[string]void)
|
||||
|
||||
for _, header := range headers {
|
||||
// allowed headers are always checked in lowercase
|
||||
o.allowedHeaders[strings.ToLower(header)] = void{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposedHeaders sets the exposed headers for the CORS middleware.
|
||||
// By default, no headers are exposed.
|
||||
func WithExposedHeaders(headers ...string) Option {
|
||||
return func(o *options) {
|
||||
o.exposedHeaders = nil
|
||||
|
||||
for _, header := range headers {
|
||||
o.exposedHeaders = append(o.exposedHeaders, http.CanonicalHeaderKey(header))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowCredentials sets the allow credentials option for the CORS middleware.
|
||||
// This setting indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
// By default, credentials are not allowed.
|
||||
func WithAllowCredentials(allow bool) Option {
|
||||
return func(o *options) {
|
||||
o.allowCredentials = allow
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowPrivateNetworks sets the allow private networks option for the CORS middleware.
|
||||
// This setting indicates whether to accept cross-origin requests over a private network.
|
||||
func WithAllowPrivateNetworks(allow bool) Option {
|
||||
return func(o *options) {
|
||||
o.allowPrivateNetworks = allow
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxAge sets the max age (in seconds) for the CORS middleware.
|
||||
// The maximum age indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached. A value of 0 means that no Access-Control-Max-Age header is sent back,
|
||||
// resulting in browsers using their default value (5s by spec).
|
||||
// If you need to force a 0 max-age, set it to a negative value (ie: -1).
|
||||
// By default, the max age is 7200 seconds.
|
||||
func WithMaxAge(age int) Option {
|
||||
return func(o *options) {
|
||||
o.maxAge = age
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
allowedOrigins: []string{"*"},
|
||||
allowedMethods: []string{
|
||||
http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete,
|
||||
},
|
||||
allowedHeaders: map[string]void{"*": {}},
|
||||
exposedHeaders: nil,
|
||||
allowCredentials: false,
|
||||
allowPrivateNetworks: false,
|
||||
maxAge: 0,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
96
internal/app/api/core/middleware/cors/options_test.go
Normal file
96
internal/app/api/core/middleware/cors/options_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithAllowedOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origins []string
|
||||
wantNormal []string
|
||||
wantWildcard []wildcard
|
||||
}{
|
||||
{
|
||||
name: "No origins",
|
||||
origins: []string{},
|
||||
wantNormal: nil,
|
||||
wantWildcard: nil,
|
||||
},
|
||||
{
|
||||
name: "Single origin",
|
||||
origins: []string{"http://example.com"},
|
||||
wantNormal: []string{"http://example.com"},
|
||||
wantWildcard: nil,
|
||||
},
|
||||
{
|
||||
name: "Wildcard origin",
|
||||
origins: []string{"http://*.example.com"},
|
||||
wantNormal: nil,
|
||||
wantWildcard: []wildcard{newWildcard("http://*.example.com")},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := newOptions(WithAllowedOrigins(tt.origins...))
|
||||
if !slices.Equal(o.allowedOrigins, tt.wantNormal) {
|
||||
t.Errorf("got %v, want %v", o, tt.wantNormal)
|
||||
}
|
||||
if !slices.Equal(o.allowedOriginPatterns, tt.wantWildcard) {
|
||||
t.Errorf("got %v, want %v", o, tt.wantWildcard)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowedMethods(t *testing.T) {
|
||||
methods := []string{http.MethodGet, http.MethodPost}
|
||||
o := newOptions(WithAllowedMethods(methods...))
|
||||
if !slices.Equal(o.allowedMethods, methods) {
|
||||
t.Errorf("got %v, want %v", o.allowedMethods, methods)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowedHeaders(t *testing.T) {
|
||||
headers := []string{"Content-Type", "Authorization"}
|
||||
o := newOptions(WithAllowedHeaders(headers...))
|
||||
expectedHeaders := map[string]void{"content-type": {}, "authorization": {}}
|
||||
if !maps.Equal(o.allowedHeaders, expectedHeaders) {
|
||||
t.Errorf("got %v, want %v", o.allowedHeaders, expectedHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithExposedHeaders(t *testing.T) {
|
||||
headers := []string{"X-Custom-Header"}
|
||||
o := newOptions(WithExposedHeaders(headers...))
|
||||
expectedHeaders := []string{http.CanonicalHeaderKey("X-Custom-Header")}
|
||||
if !slices.Equal(o.exposedHeaders, expectedHeaders) {
|
||||
t.Errorf("got %v, want %v", o.exposedHeaders, expectedHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowCredentials(t *testing.T) {
|
||||
o := newOptions(WithAllowCredentials(true))
|
||||
if !o.allowCredentials {
|
||||
t.Errorf("got %v, want %v", o.allowCredentials, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowPrivateNetworks(t *testing.T) {
|
||||
o := newOptions(WithAllowPrivateNetworks(true))
|
||||
if !o.allowPrivateNetworks {
|
||||
t.Errorf("got %v, want %v", o.allowPrivateNetworks, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMaxAge(t *testing.T) {
|
||||
maxAge := 3600
|
||||
o := newOptions(WithMaxAge(maxAge))
|
||||
if o.maxAge != maxAge {
|
||||
t.Errorf("got %v, want %v", o.maxAge, maxAge)
|
||||
}
|
||||
}
|
33
internal/app/api/core/middleware/cors/wildcard.go
Normal file
33
internal/app/api/core/middleware/cors/wildcard.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cors
|
||||
|
||||
import "strings"
|
||||
|
||||
// wildcard is a type that represents a wildcard string.
|
||||
// This type allows faster matching of strings with a wildcard
|
||||
// in comparison to using regex.
|
||||
type wildcard struct {
|
||||
prefix string
|
||||
suffix string
|
||||
}
|
||||
|
||||
// match returns true if the string s has the prefix and suffix of the wildcard.
|
||||
func (w wildcard) match(s string) bool {
|
||||
return len(s) >= len(w.prefix)+len(w.suffix) &&
|
||||
strings.HasPrefix(s, w.prefix) &&
|
||||
strings.HasSuffix(s, w.suffix)
|
||||
}
|
||||
|
||||
func newWildcard(s string) wildcard {
|
||||
if i := strings.IndexByte(s, '*'); i >= 0 {
|
||||
return wildcard{
|
||||
prefix: s[:i],
|
||||
suffix: s[i+1:],
|
||||
}
|
||||
}
|
||||
|
||||
// fallback, usually this case should not happen
|
||||
return wildcard{
|
||||
prefix: s,
|
||||
suffix: "",
|
||||
}
|
||||
}
|
94
internal/app/api/core/middleware/cors/wildcard_test.go
Normal file
94
internal/app/api/core/middleware/cors/wildcard_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package cors
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWildcardMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard wildcard
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Match with prefix and suffix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "http://sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match with different prefix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "https://sub.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No match with different suffix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "http://sub.example.org",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Match with empty suffix",
|
||||
wildcard: newWildcard("http://*"),
|
||||
input: "http://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Match with empty prefix",
|
||||
wildcard: newWildcard("*.example.com"),
|
||||
input: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match with empty prefix and different suffix",
|
||||
wildcard: newWildcard("*.example.com"),
|
||||
input: "sub.example.org",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.wildcard.match(tt.input); got != tt.expected {
|
||||
t.Errorf("wildcard.match(%s) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWildcard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected wildcard
|
||||
}{
|
||||
{
|
||||
name: "Wildcard with prefix and suffix",
|
||||
input: "http://*.example.com",
|
||||
expected: wildcard{prefix: "http://", suffix: ".example.com"},
|
||||
},
|
||||
{
|
||||
name: "Wildcard with empty suffix",
|
||||
input: "http://*",
|
||||
expected: wildcard{prefix: "http://", suffix: ""},
|
||||
},
|
||||
{
|
||||
name: "Wildcard with empty prefix",
|
||||
input: "*.example.com",
|
||||
expected: wildcard{prefix: "", suffix: ".example.com"},
|
||||
},
|
||||
{
|
||||
name: "No wildcard character",
|
||||
input: "http://example.com",
|
||||
expected: wildcard{prefix: "http://example.com", suffix: ""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := newWildcard(tt.input); got != tt.expected {
|
||||
t.Errorf("newWildcard(%s) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
137
internal/app/api/core/middleware/csrf/middleware.go
Normal file
137
internal/app/api/core/middleware/csrf/middleware.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// ContextValueIdentifier is the context value identifier for the CSRF token.
|
||||
// The token is only stored in the context if the RefreshToken function was called before.
|
||||
const ContextValueIdentifier = "_csrf_token"
|
||||
|
||||
// Middleware is a type that creates a new CSRF middleware. The CSRF middleware
|
||||
// can be used to mitigate Cross-Site Request Forgery attacks.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new CSRF middleware with the provided options.
|
||||
func New(sessionReader SessionReader, sessionWriter SessionWriter, opts ...Option) *Middleware {
|
||||
opts = append(opts, withSessionReader(sessionReader), withSessionWriter(sessionWriter))
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
checkForPRNG()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the CSRF middleware handler. This middleware validates the CSRF token and calls the specified
|
||||
// error handler if an invalid CSRF token was found.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if slices.Contains(m.o.ignoreMethods, r.Method) {
|
||||
next.ServeHTTP(w, r) // skip CSRF check for ignored methods
|
||||
return
|
||||
}
|
||||
|
||||
// get the token from the request
|
||||
token := m.o.tokenGetter(r)
|
||||
storedToken := m.o.sessionGetter(r)
|
||||
|
||||
if !tokenEqual(token, storedToken) {
|
||||
m.o.errCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken generates a new CSRF Token and stores it in the session. The token is also passed to subsequent handlers
|
||||
// via the context value ContextValueIdentifier.
|
||||
func (m *Middleware) RefreshToken(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if GetToken(r.Context()) != "" {
|
||||
// token already generated higher up in the chain
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// generate a new token
|
||||
token := generateToken(m.o.tokenLength)
|
||||
key := generateToken(m.o.tokenLength)
|
||||
|
||||
// mask the token
|
||||
maskedToken := maskToken(token, key)
|
||||
|
||||
// store the encoded token in the session
|
||||
encodedToken := encodeToken(maskedToken)
|
||||
m.o.sessionWriter(r, encodedToken)
|
||||
|
||||
// pass the token down the chain via the context
|
||||
r = r.WithContext(setToken(r.Context(), encodedToken))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// region token-access
|
||||
|
||||
// GetToken retrieves the CSRF token from the given context. Ensure that the RefreshToken function was called before,
|
||||
// otherwise, no token is populated in the context.
|
||||
func GetToken(ctx context.Context) string {
|
||||
token, ok := ctx.Value(ContextValueIdentifier).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// endregion token-access
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
func setToken(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, ContextValueIdentifier, token)
|
||||
}
|
||||
|
||||
// defaultTokenGetter is the default token getter function for the CSRF middleware.
|
||||
// It checks the request form values, URL query parameters, and headers for the CSRF token.
|
||||
// The order of precedence is:
|
||||
// 1. Header "X-CSRF-TOKEN"
|
||||
// 2. Header "X-XSRF-TOKEN"
|
||||
// 3. URL query parameter "_csrf"
|
||||
// 4. Form value "_csrf"
|
||||
func defaultTokenGetter(r *http.Request) string {
|
||||
if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.URL.Query().Get("_csrf"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.FormValue("_csrf"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultErrorHandler is the default error handler function for the CSRF middleware.
|
||||
// It writes a 403 Forbidden response.
|
||||
func defaultErrorHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
251
internal/app/api/core/middleware/csrf/middleware_test.go
Normal file
251
internal/app/api/core/middleware/csrf/middleware_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
)
|
||||
|
||||
func TestMiddleware_Handler(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{"ValidToken", "POST", "stored-token", http.StatusOK},
|
||||
{"ValidToken2", "PUT", "stored-token", http.StatusOK},
|
||||
{"ValidToken3", "GET", "stored-token", http.StatusOK},
|
||||
{"InvalidToken", "POST", "invalid-token", http.StatusForbidden},
|
||||
{"IgnoredMethod", "GET", "", http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
req.Header.Set("X-CSRF-TOKEN", tt.token)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != tt.wantStatus {
|
||||
t.Errorf("Handler() status = %d, want %d", status, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken(t *testing.T) {
|
||||
sessionToken := ""
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := GetToken(r.Context())
|
||||
if token == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
|
||||
if sessionToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken_chained(t *testing.T) {
|
||||
sessionToken := ""
|
||||
tokenWrites := 0
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
tokenWrites++
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.RefreshToken(m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := GetToken(r.Context())
|
||||
if token == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
|
||||
if sessionToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in session")
|
||||
}
|
||||
|
||||
if tokenWrites != 1 {
|
||||
t.Errorf("RefreshToken() wrote token to session more than once: %d", tokenWrites)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken_Handler(t *testing.T) {
|
||||
sessionToken := ""
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
// simulate two requests: first one GET request with the RefreshToken handler, the next one is a PUT request with
|
||||
// the token from the first request added as X-CSRF-TOKEN header
|
||||
|
||||
// first request
|
||||
retrievedToken := ""
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
retrievedToken = GetToken(r.Context())
|
||||
if retrievedToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusAccepted {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusAccepted)
|
||||
}
|
||||
if retrievedToken == "" {
|
||||
t.Errorf("no token retrieved")
|
||||
}
|
||||
if retrievedToken != sessionToken {
|
||||
t.Errorf("token in context does not match token in session")
|
||||
}
|
||||
|
||||
// second request
|
||||
req = httptest.NewRequest("PUT", "/", nil)
|
||||
req.Header.Set("X-CSRF-TOKEN", retrievedToken)
|
||||
rr = httptest.NewRecorder()
|
||||
handler = m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_FormBody(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyData, err := request.BodyString(r)
|
||||
if err != nil {
|
||||
t.Errorf("Handler() error = %v, want nil", err)
|
||||
}
|
||||
// ensure that the body is empty - ParseForm() should have been called before by the CSRF middleware
|
||||
if bodyData != "" {
|
||||
t.Errorf("Handler() bodyData = %s, want empty", bodyData)
|
||||
}
|
||||
|
||||
if r.FormValue("_csrf") != "stored-token" {
|
||||
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Form = make(map[string][]string)
|
||||
req.Form.Add("_csrf", "stored-token")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_FormBodyAvailable(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyData, err := request.BodyString(r)
|
||||
if err != nil {
|
||||
t.Errorf("Handler() error = %v, want nil", err)
|
||||
}
|
||||
// ensure that the body is not empty, as the CSRF middleware should not have read the body
|
||||
if bodyData != "the original body" {
|
||||
t.Errorf("Handler() bodyData = %s, want %s", bodyData, "the original body")
|
||||
}
|
||||
|
||||
// check if the token is available in the form values (from query parameters)
|
||||
if r.FormValue("_csrf") != "stored-token" {
|
||||
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/?_csrf=stored-token", nil)
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
req.Body = io.NopCloser(strings.NewReader("the original body"))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
88
internal/app/api/core/middleware/csrf/options.go
Normal file
88
internal/app/api/core/middleware/csrf/options.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package csrf
|
||||
|
||||
import "net/http"
|
||||
|
||||
type SessionReader func(r *http.Request) string
|
||||
type SessionWriter func(r *http.Request, token string)
|
||||
|
||||
// options is a struct that contains options for the CSRF middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
tokenLength int
|
||||
ignoreMethods []string
|
||||
|
||||
errCallbackOverride bool
|
||||
errCallback func(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
tokenGetterOverride bool
|
||||
tokenGetter func(r *http.Request) string
|
||||
|
||||
sessionGetter SessionReader
|
||||
sessionWriter SessionWriter
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the CSRF middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithTokenLength is a method that sets the token length for the CSRF middleware.
|
||||
// The default value is 32.
|
||||
func WithTokenLength(length int) Option {
|
||||
return func(o *options) {
|
||||
o.tokenLength = length
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorCallback is a method that sets the error callback function for the CSRF middleware.
|
||||
// The error callback function is called when the CSRF token is invalid.
|
||||
// The default behavior is to write a 403 Forbidden response.
|
||||
func WithErrorCallback(fn func(w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.errCallback = fn
|
||||
o.errCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenGetter is a method that sets the token getter function for the CSRF middleware.
|
||||
// The token getter function is called to get the CSRF token from the request.
|
||||
// The default behavior is to get the token from the "X-CSRF-Token" header.
|
||||
func WithTokenGetter(fn func(r *http.Request) string) Option {
|
||||
return func(o *options) {
|
||||
o.tokenGetter = fn
|
||||
o.tokenGetterOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// withSessionReader is a method that sets the session reader function for the CSRF middleware.
|
||||
// The session reader function is called to get the CSRF token from the session.
|
||||
func withSessionReader(fn SessionReader) Option {
|
||||
return func(o *options) {
|
||||
o.sessionGetter = fn
|
||||
}
|
||||
}
|
||||
|
||||
// withSessionWriter is a method that sets the session writer function for the CSRF middleware.
|
||||
// The session writer function is called to write the CSRF token to the session.
|
||||
func withSessionWriter(fn SessionWriter) Option {
|
||||
return func(o *options) {
|
||||
o.sessionWriter = fn
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
tokenLength: 32,
|
||||
ignoreMethods: []string{"GET", "HEAD", "OPTIONS"},
|
||||
errCallbackOverride: false,
|
||||
errCallback: defaultErrorHandler,
|
||||
tokenGetterOverride: false,
|
||||
tokenGetter: defaultTokenGetter,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
75
internal/app/api/core/middleware/csrf/options_test.go
Normal file
75
internal/app/api/core/middleware/csrf/options_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithTokenLength(t *testing.T) {
|
||||
o := newOptions(WithTokenLength(64))
|
||||
if o.tokenLength != 64 {
|
||||
t.Errorf("WithTokenLength() = %d, want %d", o.tokenLength, 64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithErrorCallback(t *testing.T) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
}
|
||||
o := newOptions(WithErrorCallback(callback))
|
||||
if !o.errCallbackOverride {
|
||||
t.Errorf("WithErrorCallback() did not set errCallbackOverride to true")
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("WithErrorCallback() did not set errCallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTokenGetter(t *testing.T) {
|
||||
getter := func(r *http.Request) string {
|
||||
return "test-token"
|
||||
}
|
||||
o := newOptions(WithTokenGetter(getter))
|
||||
if !o.tokenGetterOverride {
|
||||
t.Errorf("WithTokenGetter() did not set tokenGetterOverride to true")
|
||||
}
|
||||
if o.tokenGetter == nil {
|
||||
t.Errorf("WithTokenGetter() did not set tokenGetter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSessionReader(t *testing.T) {
|
||||
reader := func(r *http.Request) string {
|
||||
return "session-token"
|
||||
}
|
||||
o := newOptions(withSessionReader(reader))
|
||||
if o.sessionGetter == nil {
|
||||
t.Errorf("withSessionReader() did not set sessionGetter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSessionWriter(t *testing.T) {
|
||||
writer := func(r *http.Request, token string) {
|
||||
// do nothing
|
||||
}
|
||||
o := newOptions(withSessionWriter(writer))
|
||||
if o.sessionWriter == nil {
|
||||
t.Errorf("withSessionWriter() did not set sessionWriter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOptionsDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
if o.tokenLength != 32 {
|
||||
t.Errorf("newOptions() default tokenLength = %d, want %d", o.tokenLength, 32)
|
||||
}
|
||||
if len(o.ignoreMethods) != 3 {
|
||||
t.Errorf("newOptions() default ignoreMethods length = %d, want %d", len(o.ignoreMethods), 3)
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("newOptions() default errCallback is nil")
|
||||
}
|
||||
if o.tokenGetter == nil {
|
||||
t.Errorf("newOptions() default tokenGetter is nil")
|
||||
}
|
||||
}
|
90
internal/app/api/core/middleware/csrf/token.go
Normal file
90
internal/app/api/core/middleware/csrf/token.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// checkForPRNG is a function that checks if a cryptographically secure PRNG is available.
|
||||
// If it is not available, the function panics.
|
||||
func checkForPRNG() {
|
||||
buf := make([]byte, 1)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// generateToken is a function that generates a secure random CSRF token.
|
||||
func generateToken(length int) []byte {
|
||||
bytes := make([]byte, length)
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return bytes
|
||||
}
|
||||
|
||||
// encodeToken is a function that encodes a token to a base64 string.
|
||||
func encodeToken(token []byte) string {
|
||||
return base64.URLEncoding.EncodeToString(token)
|
||||
}
|
||||
|
||||
// decodeToken is a function that decodes a base64 string to a token.
|
||||
func decodeToken(token string) ([]byte, error) {
|
||||
return base64.URLEncoding.DecodeString(token)
|
||||
}
|
||||
|
||||
// maskToken is a function that masks a token with a given key.
|
||||
// The returned byte slice contains the key + the masked token.
|
||||
// The key needs to have the same length as the token, otherwise the function panics.
|
||||
// So the resulting slice has a length of len(token) * 2.
|
||||
func maskToken(token, key []byte) []byte {
|
||||
if len(token) != len(key) {
|
||||
panic("token and key must have the same length")
|
||||
}
|
||||
|
||||
// masked contains the key in the first half and the XOR masked token in the second half
|
||||
tokenLength := len(token)
|
||||
masked := make([]byte, tokenLength*2)
|
||||
for i := 0; i < len(token); i++ {
|
||||
masked[i] = key[i]
|
||||
masked[i+tokenLength] = token[i] ^ key[i] // XOR mask
|
||||
}
|
||||
|
||||
return masked
|
||||
}
|
||||
|
||||
// unmaskToken is a function that unmask a token which contains the key in the first half.
|
||||
// The returned byte slice contains the unmasked token, it has exactly half the length of the input slice.
|
||||
func unmaskToken(masked []byte) []byte {
|
||||
tokenLength := len(masked) / 2
|
||||
token := make([]byte, tokenLength)
|
||||
for i := 0; i < tokenLength; i++ {
|
||||
token[i] = masked[i] ^ masked[i+tokenLength] // XOR unmask
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// tokenEqual is a function that compares two tokens for equality.
|
||||
func tokenEqual(a, b string) bool {
|
||||
decodedA, err := decodeToken(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
decodedB, err := decodeToken(b)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
unmaskedA := unmaskToken(decodedA)
|
||||
unmaskedB := unmaskToken(decodedB)
|
||||
|
||||
return slices.Equal(unmaskedA, unmaskedB)
|
||||
}
|
81
internal/app/api/core/middleware/csrf/token_test.go
Normal file
81
internal/app/api/core/middleware/csrf/token_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckForPRNG(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("checkForPRNG() panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
checkForPRNG()
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
length := 32
|
||||
token := generateToken(length)
|
||||
if len(token) != length {
|
||||
t.Errorf("generateToken() returned token of length %d, expected %d", len(token), length)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
encoded := encodeToken(token)
|
||||
expected := base64.URLEncoding.EncodeToString(token)
|
||||
if encoded != expected {
|
||||
t.Errorf("encodeToken() = %v, want %v", encoded, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken(t *testing.T) {
|
||||
token := "dGVzdHRva2Vu"
|
||||
expected := []byte("testtoken")
|
||||
decoded, err := decodeToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("decodeToken() error = %v", err)
|
||||
}
|
||||
if string(decoded) != string(expected) {
|
||||
t.Errorf("decodeToken() = %v, want %v", decoded, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
key := []byte("keykeykey")
|
||||
masked := maskToken(token, key)
|
||||
if len(masked) != len(token)*2 {
|
||||
t.Errorf("maskToken() returned masked token of length %d, expected %d", len(masked), len(token)*2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmaskToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
key := []byte("keykeykey")
|
||||
masked := maskToken(token, key)
|
||||
unmasked := unmaskToken(masked)
|
||||
if string(unmasked) != string(token) {
|
||||
t.Errorf("unmaskToken() = %v, want %v", unmasked, token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenEqual(t *testing.T) {
|
||||
tokenA := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}))
|
||||
tokenB := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
|
||||
if !tokenEqual(tokenA, tokenB) {
|
||||
t.Errorf("tokenEqual() = false, want true")
|
||||
}
|
||||
|
||||
tokenC := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x07, 0x08, 0x09}))
|
||||
if !tokenEqual(tokenA, tokenC) {
|
||||
t.Errorf("tokenEqual() = false, want true")
|
||||
}
|
||||
|
||||
tokenD := encodeToken(maskToken([]byte{0x09, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
|
||||
if tokenEqual(tokenA, tokenD) {
|
||||
t.Errorf("tokenEqual() = true, want false")
|
||||
}
|
||||
}
|
199
internal/app/api/core/middleware/logging/middleware.go
Normal file
199
internal/app/api/core/middleware/logging/middleware.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogLevel is an enumeration of the different log levels.
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LogLevelDebug LogLevel = iota
|
||||
LogLevelInfo
|
||||
LogLevelWarn
|
||||
LogLevelError
|
||||
)
|
||||
|
||||
// Logger is an interface that defines the methods that a logger must implement.
|
||||
// This allows the logging middleware to be used with different logging libraries.
|
||||
type Logger interface {
|
||||
// Debugf logs a message at debug level.
|
||||
Debugf(format string, args ...any)
|
||||
// Infof logs a message at info level.
|
||||
Infof(format string, args ...any)
|
||||
// Warnf logs a message at warn level.
|
||||
Warnf(format string, args ...any)
|
||||
// Errorf logs a message at error level.
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// Middleware is a type that creates a new logging middleware. The logging middleware
|
||||
// logs information about each request.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new logging middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the logging middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := newWriterWrapper(w)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
info := m.extractInfoMap(r, start, ww)
|
||||
|
||||
if m.o.logger == nil {
|
||||
msg, args := m.buildSlogMessageAndArguments(info)
|
||||
m.logMsg(msg, args...)
|
||||
} else {
|
||||
msg := m.buildNormalLogMessage(info)
|
||||
m.logMsg(msg)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) extractInfoMap(r *http.Request, start time.Time, ww *writerWrapper) map[string]any {
|
||||
info := make(map[string]any)
|
||||
|
||||
info["method"] = r.Method
|
||||
info["path"] = r.URL.Path
|
||||
info["protocol"] = r.Proto
|
||||
info["clientIP"] = r.Header.Get("X-Forwarded-For")
|
||||
if info["clientIP"] == "" {
|
||||
// If the X-Forwarded-For header is not set, use the remote address without the port number.
|
||||
lastColonIndex := strings.LastIndex(r.RemoteAddr, ":")
|
||||
switch lastColonIndex {
|
||||
case -1:
|
||||
info["clientIP"] = r.RemoteAddr
|
||||
default:
|
||||
info["clientIP"] = r.RemoteAddr[:lastColonIndex]
|
||||
}
|
||||
}
|
||||
info["userAgent"] = r.UserAgent()
|
||||
info["referer"] = r.Header.Get("Referer")
|
||||
info["duration"] = time.Since(start).String()
|
||||
info["status"] = ww.StatusCode
|
||||
info["dataLength"] = ww.WrittenBytes
|
||||
|
||||
if m.o.headerRequestIdKey != "" {
|
||||
info["headerRequestId"] = r.Header.Get(m.o.headerRequestIdKey)
|
||||
}
|
||||
if m.o.contextRequestIdKey != "" {
|
||||
info["contextRequestId"], _ = r.Context().Value(m.o.contextRequestIdKey).(string)
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func (m *Middleware) buildNormalLogMessage(info map[string]any) string {
|
||||
switch {
|
||||
case info["headerRequestId"] != nil && info["contextRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s ctx=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["headerRequestId"], info["contextRequestId"])
|
||||
case info["headerRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["headerRequestId"])
|
||||
case info["contextRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - ctx=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["contextRequestId"])
|
||||
default:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) buildSlogMessageAndArguments(info map[string]any) (message string, args []any) {
|
||||
message = fmt.Sprintf("%s %s", info["method"], info["path"])
|
||||
|
||||
// Use a fixed order for the keys, so that the message is always the same.
|
||||
// Skip method and path as they are already in the message.
|
||||
keys := []string{
|
||||
"protocol",
|
||||
"status",
|
||||
"dataLength",
|
||||
"duration",
|
||||
"clientIP",
|
||||
"userAgent",
|
||||
"referer",
|
||||
"headerRequestId",
|
||||
"contextRequestId",
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := info[k]; ok {
|
||||
args = append(args, k, v) // only add key, value if it exists
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Middleware) addPrefix(message string) string {
|
||||
if m.o.prefix != "" {
|
||||
return m.o.prefix + " " + message
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (m *Middleware) logMsg(message string, args ...any) {
|
||||
message = m.addPrefix(message)
|
||||
|
||||
if m.o.logger != nil {
|
||||
switch m.o.logLevel {
|
||||
case LogLevelDebug:
|
||||
m.o.logger.Debugf(message, args...)
|
||||
case LogLevelInfo:
|
||||
m.o.logger.Infof(message, args...)
|
||||
case LogLevelWarn:
|
||||
m.o.logger.Warnf(message, args...)
|
||||
case LogLevelError:
|
||||
m.o.logger.Errorf(message, args...)
|
||||
default:
|
||||
m.o.logger.Infof(message, args...)
|
||||
}
|
||||
} else {
|
||||
switch m.o.logLevel {
|
||||
case LogLevelDebug:
|
||||
slog.Debug(message, args...)
|
||||
case LogLevelInfo:
|
||||
slog.Info(message, args...)
|
||||
case LogLevelWarn:
|
||||
slog.Warn(message, args...)
|
||||
case LogLevelError:
|
||||
slog.Error(message, args...)
|
||||
default:
|
||||
slog.Info(message, args...)
|
||||
}
|
||||
}
|
||||
}
|
148
internal/app/api/core/middleware/logging/middleware_test.go
Normal file
148
internal/app/api/core/middleware/logging/middleware_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockLogger struct {
|
||||
messages []string
|
||||
}
|
||||
|
||||
func (m *mockLogger) Debugf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "DEBUG: "+format)
|
||||
}
|
||||
func (m *mockLogger) Infof(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "INFO: "+format)
|
||||
}
|
||||
func (m *mockLogger) Warnf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "WARN: "+format)
|
||||
}
|
||||
func (m *mockLogger) Errorf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "ERROR: "+format)
|
||||
}
|
||||
|
||||
func TestMiddleware_Normal(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
|
||||
}
|
||||
|
||||
expected := "Hello, World!"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
|
||||
}
|
||||
|
||||
if len(logger.messages) == 0 {
|
||||
t.Errorf("expected log messages, got none")
|
||||
}
|
||||
|
||||
if len(logger.messages) != 0 && !strings.Contains(logger.messages[0], "ERROR: GET /foo") {
|
||||
t.Errorf("expected log message to contain request info, got %v", logger.messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Extended(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithContextRequestIdKey("requestId"), WithHeaderRequestIdKey("X-Request-Id")).
|
||||
Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
|
||||
}
|
||||
|
||||
expected := "Hello, World!"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddr(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "xhamster.com:1234"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrNoPort(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "xhamster.com"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrV6(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "[::1]:4711"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrV6NoPort(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "[::1]"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
80
internal/app/api/core/middleware/logging/options.go
Normal file
80
internal/app/api/core/middleware/logging/options.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package logging
|
||||
|
||||
// options is a struct that contains options for the logging middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
logLevel LogLevel
|
||||
logger Logger
|
||||
prefix string
|
||||
|
||||
contextRequestIdKey string
|
||||
headerRequestIdKey string
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the logging middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithLevel is a method that sets the log level for the logging middleware.
|
||||
// Possible values are LogLevelDebug, LogLevelInfo, LogLevelWarn, and LogLevelError.
|
||||
// The default value is LogLevelInfo.
|
||||
func WithLevel(level LogLevel) Option {
|
||||
return func(o *options) {
|
||||
o.logLevel = level
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrefix is a method that sets the prefix for the logging middleware.
|
||||
// If a prefix is set, it will be prepended to each log message. A space will
|
||||
// be added between the prefix and the log message.
|
||||
// The default value is an empty string.
|
||||
func WithPrefix(prefix string) Option {
|
||||
return func(o *options) {
|
||||
o.prefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithContextRequestIdKey is a method that sets the key for the request ID in the
|
||||
// request context. If a key is set, the logging middleware will use this key to
|
||||
// retrieve the request ID from the request context.
|
||||
// The default value is an empty string, meaning the request ID will not be logged.
|
||||
func WithContextRequestIdKey(key string) Option {
|
||||
return func(o *options) {
|
||||
o.contextRequestIdKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaderRequestIdKey is a method that sets the key for the request ID in the
|
||||
// request headers. If a key is set, the logging middleware will use this key to
|
||||
// retrieve the request ID from the request headers.
|
||||
// The default value is an empty string, meaning the request ID will not be logged.
|
||||
func WithHeaderRequestIdKey(key string) Option {
|
||||
return func(o *options) {
|
||||
o.headerRequestIdKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger is a method that sets the logger for the logging middleware.
|
||||
// If a logger is set, the logging middleware will use this logger to log messages.
|
||||
// The default logger is the structured slog logger.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(o *options) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
logLevel: LogLevelInfo,
|
||||
logger: nil,
|
||||
prefix: "",
|
||||
contextRequestIdKey: "",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
88
internal/app/api/core/middleware/logging/options_test.go
Normal file
88
internal/app/api/core/middleware/logging/options_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithLevel(t *testing.T) {
|
||||
// table test to check all possible log levels
|
||||
levels := []LogLevel{
|
||||
LogLevelDebug,
|
||||
LogLevelInfo,
|
||||
LogLevelWarn,
|
||||
LogLevelError,
|
||||
}
|
||||
|
||||
for _, level := range levels {
|
||||
opt := WithLevel(level)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logLevel != level {
|
||||
t.Errorf("expected log level to be %v, got %v", level, o.logLevel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithPrefix(t *testing.T) {
|
||||
prefix := "TEST"
|
||||
opt := WithPrefix(prefix)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.prefix != prefix {
|
||||
t.Errorf("expected prefix to be %v, got %v", prefix, o.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithContextRequestIdKey(t *testing.T) {
|
||||
key := "contextKey"
|
||||
opt := WithContextRequestIdKey(key)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.contextRequestIdKey != key {
|
||||
t.Errorf("expected contextRequestIdKey to be %v, got %v", key, o.contextRequestIdKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithHeaderRequestIdKey(t *testing.T) {
|
||||
key := "headerKey"
|
||||
opt := WithHeaderRequestIdKey(key)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.headerRequestIdKey != key {
|
||||
t.Errorf("expected headerRequestIdKey to be %v, got %v", key, o.headerRequestIdKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogger(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
opt := WithLogger(logger)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logger != logger {
|
||||
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.logLevel != LogLevelInfo {
|
||||
t.Errorf("expected log level to be %v, got %v", LogLevelInfo, o.logLevel)
|
||||
}
|
||||
|
||||
if o.logger != nil {
|
||||
t.Errorf("expected logger to be nil, got %v", o.logger)
|
||||
}
|
||||
|
||||
if o.prefix != "" {
|
||||
t.Errorf("expected prefix to be empty, got %v", o.prefix)
|
||||
}
|
||||
|
||||
if o.contextRequestIdKey != "" {
|
||||
t.Errorf("expected contextRequestIdKey to be empty, got %v", o.contextRequestIdKey)
|
||||
}
|
||||
|
||||
if o.headerRequestIdKey != "" {
|
||||
t.Errorf("expected headerRequestIdKey to be empty, got %v", o.headerRequestIdKey)
|
||||
}
|
||||
}
|
45
internal/app/api/core/middleware/logging/writer.go
Normal file
45
internal/app/api/core/middleware/logging/writer.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// writerWrapper wraps a http.ResponseWriter and tracks the number of bytes written to it.
|
||||
// It also tracks the http response code passed to the WriteHeader func of
|
||||
// the ResponseWriter.
|
||||
type writerWrapper struct {
|
||||
http.ResponseWriter
|
||||
|
||||
// StatusCode is the last http response code passed to the WriteHeader func of
|
||||
// the ResponseWriter. If no such call is made, a default code of http.StatusOK
|
||||
// is assumed instead.
|
||||
StatusCode int
|
||||
|
||||
// WrittenBytes is the number of bytes successfully written by the Write or
|
||||
// ReadFrom function of the ResponseWriter. ResponseWriters may also write
|
||||
// data to their underlaying connection directly (e.g. headers), but those
|
||||
// are not tracked. Therefor the number of Written bytes will usually match
|
||||
// the size of the response body.
|
||||
WrittenBytes int64
|
||||
}
|
||||
|
||||
// WriteHeader wraps the WriteHeader method of the ResponseWriter and tracks the
|
||||
// http response code passed to it.
|
||||
func (w *writerWrapper) WriteHeader(code int) {
|
||||
w.StatusCode = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Write wraps the Write method of the ResponseWriter and tracks the number of bytes
|
||||
// written to it.
|
||||
func (w *writerWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
w.WrittenBytes += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter.
|
||||
// It initializes the StatusCode to http.StatusOK.
|
||||
func newWriterWrapper(w http.ResponseWriter) *writerWrapper {
|
||||
return &writerWrapper{ResponseWriter: w, StatusCode: http.StatusOK}
|
||||
}
|
85
internal/app/api/core/middleware/logging/writer_test.go
Normal file
85
internal/app/api/core/middleware/logging/writer_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriterWrapper_WriteHeader(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
ww.WriteHeader(http.StatusNotFound)
|
||||
|
||||
if ww.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusNotFound, ww.StatusCode)
|
||||
}
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Errorf("expected recorder status code to be %v, got %v", http.StatusNotFound, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterWrapper_Write(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
n, err := ww.Write(data)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
|
||||
}
|
||||
if ww.WrittenBytes != int64(len(data)) {
|
||||
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
|
||||
}
|
||||
if rr.Body.String() != string(data) {
|
||||
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterWrapper_WriteWithHeaders(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
n, err := ww.Write(data)
|
||||
|
||||
ww.Header().Set("Content-Type", "text/plain")
|
||||
ww.Header().Set("X-Some-Header", "some-value")
|
||||
ww.WriteHeader(http.StatusTeapot)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
|
||||
}
|
||||
if ww.WrittenBytes != int64(len(data)) {
|
||||
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
|
||||
}
|
||||
if rr.Body.String() != string(data) {
|
||||
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
|
||||
}
|
||||
if ww.StatusCode != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, ww.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWriterWrapper(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
if ww.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected initial status code to be %v, got %v", http.StatusOK, ww.StatusCode)
|
||||
}
|
||||
if ww.WrittenBytes != 0 {
|
||||
t.Errorf("expected initial WrittenBytes to be %v, got %v", 0, ww.WrittenBytes)
|
||||
}
|
||||
if ww.ResponseWriter != rr {
|
||||
t.Errorf("expected ResponseWriter to be %v, got %v", rr, ww.ResponseWriter)
|
||||
}
|
||||
}
|
133
internal/app/api/core/middleware/recovery/middleware.go
Normal file
133
internal/app/api/core/middleware/recovery/middleware.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Logger is an interface that defines the methods that a logger must implement.
|
||||
// This allows the logging middleware to be used with different logging libraries.
|
||||
type Logger interface {
|
||||
// Errorf logs a message at error level.
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// Middleware is a type that creates a new recovery middleware. The recovery middleware
|
||||
// recovers from panics and returns an Internal Server Error response. This middleware should
|
||||
// be the first middleware in the middleware chain, so that it can recover from panics in other
|
||||
// middlewares.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new recovery middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the recovery middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
stack := debug.Stack()
|
||||
|
||||
realErr, ok := err.(error)
|
||||
if !ok {
|
||||
realErr = fmt.Errorf("%v", err)
|
||||
}
|
||||
|
||||
// Check for a broken connection, as it is not really a
|
||||
// condition that warrants a panic stack trace.
|
||||
brokenPipe := isBrokenPipeError(realErr)
|
||||
|
||||
// Log the error and stack trace
|
||||
if m.o.logCallback != nil {
|
||||
m.o.logCallback(realErr, stack, brokenPipe)
|
||||
}
|
||||
|
||||
switch {
|
||||
case brokenPipe && m.o.brokenPipeCallback != nil:
|
||||
m.o.brokenPipeCallback(realErr, stack, w, r)
|
||||
case !brokenPipe && m.o.errCallback != nil:
|
||||
m.o.errCallback(realErr, stack, w, r)
|
||||
default:
|
||||
// no callback set, simply recover and do nothing...
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func addPrefix(o options, message string) string {
|
||||
if o.defaultLogPrefix != "" {
|
||||
return o.defaultLogPrefix + " " + message
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
// defaultErrCallback is the default error callback function for the recovery middleware.
|
||||
// It writes a JSON response with an Internal Server Error status code. If the exposeStackTrace option is
|
||||
// enabled, the stack trace is included in the response.
|
||||
func getDefaultErrCallback(o options) func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
return func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
responseBody := map[string]interface{}{
|
||||
"error": "Internal Server Error",
|
||||
}
|
||||
if o.exposeStackTrace && len(stack) > 0 {
|
||||
responseBody["stack"] = string(stack)
|
||||
}
|
||||
|
||||
jsonBody, _ := json.Marshal(responseBody)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write(jsonBody)
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultLogCallback is the default log callback function for the recovery middleware.
|
||||
// It logs the error and stack trace using the structured slog logger or the provided logger in Error level.
|
||||
func getDefaultLogCallback(o options) func(error, []byte, bool) {
|
||||
return func(err error, stack []byte, brokenPipe bool) {
|
||||
if brokenPipe {
|
||||
return // by default, ignore broken pipe errors
|
||||
}
|
||||
|
||||
switch {
|
||||
case o.useSlog:
|
||||
slog.Error(addPrefix(o, err.Error()), "stack", string(stack))
|
||||
case o.logger != nil:
|
||||
o.logger.Errorf(fmt.Sprintf("%s; stacktrace=%s", addPrefix(o, err.Error()), string(stack)))
|
||||
default:
|
||||
// no logger set, do nothing...
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isBrokenPipeError(err error) bool {
|
||||
var syscallErr *os.SyscallError
|
||||
if errors.As(err, &syscallErr) {
|
||||
errMsg := strings.ToLower(syscallErr.Err.Error())
|
||||
if strings.Contains(errMsg, "broken pipe") ||
|
||||
strings.Contains(errMsg, "connection reset by peer") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
149
internal/app/api/core/middleware/recovery/middleware_test.go
Normal file
149
internal/app/api/core/middleware/recovery/middleware_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockLogger struct{}
|
||||
|
||||
func (m *mockLogger) Errorf(_ string, _ ...any) {}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options []Option
|
||||
panicSimulator func()
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectStack bool
|
||||
}{
|
||||
{
|
||||
name: "default behavior",
|
||||
options: []Option{},
|
||||
panicSimulator: func() {
|
||||
panic(errors.New("test panic"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: `{"error":"Internal Server Error"}`,
|
||||
},
|
||||
{
|
||||
name: "custom error callback",
|
||||
options: []Option{
|
||||
WithErrCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
w.Write([]byte("custom error"))
|
||||
}),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic(errors.New("test panic"))
|
||||
},
|
||||
expectedStatus: http.StatusTeapot,
|
||||
expectedBody: "custom error",
|
||||
},
|
||||
{
|
||||
name: "broken pipe error",
|
||||
options: []Option{
|
||||
WithBrokenPipeCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("broken pipe"))
|
||||
}),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic(&os.SyscallError{Err: errors.New("broken pipe")})
|
||||
},
|
||||
expectedStatus: http.StatusServiceUnavailable,
|
||||
expectedBody: "broken pipe",
|
||||
},
|
||||
{
|
||||
name: "default callback broken pipe error",
|
||||
options: nil,
|
||||
panicSimulator: func() {
|
||||
panic(&os.SyscallError{Err: errors.New("broken pipe")})
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "",
|
||||
},
|
||||
{
|
||||
name: "default callback normal error",
|
||||
options: nil,
|
||||
panicSimulator: func() {
|
||||
panic("something went wrong")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "{\"error\":\"Internal Server Error\"}",
|
||||
},
|
||||
{
|
||||
name: "default callback with stack trace",
|
||||
options: []Option{
|
||||
WithExposeStackTrace(true),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic("something went wrong")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "\"stack\":",
|
||||
expectStack: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := New(tt.options...).Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tt.panicSimulator()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %v, got %v", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
if !tt.expectStack && rr.Body.String() != tt.expectedBody {
|
||||
t.Errorf("expected body %v, got %v", tt.expectedBody, rr.Body.String())
|
||||
}
|
||||
if tt.expectStack && !strings.Contains(rr.Body.String(), tt.expectedBody) {
|
||||
t.Errorf("expected body to contain %v, got %v", tt.expectedBody, rr.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBrokenPipeError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "broken pipe error",
|
||||
err: &os.SyscallError{Err: errors.New("broken pipe")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "connection reset by peer error",
|
||||
err: &os.SyscallError{Err: errors.New("connection reset by peer")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isBrokenPipeError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
129
internal/app/api/core/middleware/recovery/options.go
Normal file
129
internal/app/api/core/middleware/recovery/options.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package recovery
|
||||
|
||||
import "net/http"
|
||||
|
||||
// options is a struct that contains options for the recovery middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
logger Logger
|
||||
useSlog bool
|
||||
|
||||
errCallbackOverride bool
|
||||
errCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
|
||||
brokenPipeCallbackOverride bool
|
||||
brokenPipeCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
exposeStackTrace bool
|
||||
defaultLogPrefix string
|
||||
logCallbackOverride bool
|
||||
logCallback func(err error, stack []byte, brokenPipe bool)
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the recovery middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithErrCallback sets the error callback function for the recovery middleware.
|
||||
// The error callback function is called when a panic is recovered by the middleware.
|
||||
// This function completely overrides the default behavior of the middleware. It is the
|
||||
// responsibility of the user to handle the error and write a response to the client.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithErrCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.errCallback = fn
|
||||
o.errCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithBrokenPipeCallback sets the broken pipe callback function for the recovery middleware.
|
||||
// The broken pipe callback function is called when a broken pipe error is recovered by the middleware.
|
||||
// This function completely overrides the default behavior of the middleware. It is the responsibility
|
||||
// of the user to handle the error and write a response to the client.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithBrokenPipeCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.brokenPipeCallback = fn
|
||||
o.brokenPipeCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogCallback sets the log callback function for the recovery middleware.
|
||||
// The log callback function is called when a panic is recovered by the middleware.
|
||||
// This function allows the user to log the error and stack trace. The default behavior is to log
|
||||
// the error and stack trace in Error level.
|
||||
// This function completely overrides the default behavior of the middleware.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithLogCallback(fn func(err error, stack []byte, brokenPipe bool)) Option {
|
||||
return func(o *options) {
|
||||
o.logCallback = fn
|
||||
o.logCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger is a method that sets the logger for the logging middleware.
|
||||
// If a logger is set, the logging middleware will use this logger to log messages.
|
||||
// The default logger is the structured slog logger, see WithSlog.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(o *options) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithSlog is a method that sets whether the recovery middleware should use the structured slog logger.
|
||||
// If set to true, the middleware will use the structured slog logger. If set to false, the middleware
|
||||
// will not use any logger unless one is explicitly set with the WithLogger option.
|
||||
// The default value is true.
|
||||
func WithSlog(useSlog bool) Option {
|
||||
return func(o *options) {
|
||||
o.useSlog = useSlog
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultLogPrefix is a method that sets the default log prefix for the recovery middleware.
|
||||
// If a default log prefix is set and the default log callback is used, the prefix will be prepended
|
||||
// to each log message. A space will be added between the prefix and the log message.
|
||||
// The default value is an empty string.
|
||||
func WithDefaultLogPrefix(defaultLogPrefix string) Option {
|
||||
return func(o *options) {
|
||||
o.defaultLogPrefix = defaultLogPrefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeStackTrace is a method that sets whether the stack trace should be exposed in the response.
|
||||
// If set to true, the stack trace will be included in the response body. If set to false, the stack trace
|
||||
// will not be included in the response body. This only applies to the default error callback.
|
||||
// The default value is false.
|
||||
func WithExposeStackTrace(exposeStackTrace bool) Option {
|
||||
return func(o *options) {
|
||||
o.exposeStackTrace = exposeStackTrace
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
logger: nil,
|
||||
useSlog: true,
|
||||
errCallback: nil,
|
||||
brokenPipeCallback: nil, // by default, ignore broken pipe errors
|
||||
exposeStackTrace: false,
|
||||
defaultLogPrefix: "",
|
||||
logCallback: nil,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
if o.errCallback == nil && !o.errCallbackOverride {
|
||||
o.errCallback = getDefaultErrCallback(o)
|
||||
}
|
||||
if o.logCallback == nil && !o.logCallbackOverride {
|
||||
o.logCallback = getDefaultLogCallback(o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
100
internal/app/api/core/middleware/recovery/options_test.go
Normal file
100
internal/app/api/core/middleware/recovery/options_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithErrCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
|
||||
opt := WithErrCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("expected errCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithBrokenPipeCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
|
||||
opt := WithBrokenPipeCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.brokenPipeCallback == nil {
|
||||
t.Errorf("expected brokenPipeCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, brokenPipe bool) {}
|
||||
opt := WithLogCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logCallback == nil {
|
||||
t.Errorf("expected logCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogger(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
opt := WithLogger(logger)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logger != logger {
|
||||
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSlog(t *testing.T) {
|
||||
opt := WithSlog(false)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.useSlog != false {
|
||||
t.Errorf("expected useSlog to be false, got %v", o.useSlog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithDefaultLogPrefix(t *testing.T) {
|
||||
prefix := "PREFIX"
|
||||
opt := WithDefaultLogPrefix(prefix)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.defaultLogPrefix != prefix {
|
||||
t.Errorf("expected defaultLogPrefix to be %v, got %v", prefix, o.defaultLogPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithExposeStackTrace(t *testing.T) {
|
||||
opt := WithExposeStackTrace(true)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.exposeStackTrace != true {
|
||||
t.Errorf("expected exposeStackTrace to be true, got %v", o.exposeStackTrace)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOptionsDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.logger != nil {
|
||||
t.Errorf("expected logger to be nil, got %v", o.logger)
|
||||
}
|
||||
if o.useSlog != true {
|
||||
t.Errorf("expected useSlog to be true, got %v", o.useSlog)
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("expected errCallback to be set, got nil")
|
||||
}
|
||||
if o.brokenPipeCallback != nil {
|
||||
t.Errorf("expected brokenPipeCallback to be nil, got %T", o.brokenPipeCallback)
|
||||
}
|
||||
if o.exposeStackTrace != false {
|
||||
t.Errorf("expected exposeStackTrace to be false, got %T", o.exposeStackTrace)
|
||||
}
|
||||
if o.defaultLogPrefix != "" {
|
||||
t.Errorf("expected defaultLogPrefix to be empty, got %T", o.defaultLogPrefix)
|
||||
}
|
||||
if o.logCallback == nil {
|
||||
t.Errorf("expected logCallback to be set, got nil")
|
||||
}
|
||||
}
|
69
internal/app/api/core/middleware/tracing/middleware.go
Normal file
69
internal/app/api/core/middleware/tracing/middleware.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Middleware is a type that creates a new tracing middleware. The tracing middleware
|
||||
// can be used to trace requests based on a request ID header or parameter.
|
||||
type Middleware struct {
|
||||
o options
|
||||
|
||||
seededRand *rand.Rand
|
||||
}
|
||||
|
||||
// New returns a new CORS middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
seededRand: rand.New(rand.NewSource(o.generateSeed)),
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the tracing middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var reqId string
|
||||
|
||||
// read upstream header und re-use it
|
||||
if m.o.upstreamReqIdHeader != "" {
|
||||
reqId = r.Header.Get(m.o.upstreamReqIdHeader)
|
||||
}
|
||||
|
||||
// generate new id
|
||||
if reqId == "" && m.o.generateLength > 0 {
|
||||
reqId = m.generateRandomId()
|
||||
}
|
||||
|
||||
// set response header
|
||||
if m.o.headerIdentifier != "" {
|
||||
w.Header().Set(m.o.headerIdentifier, reqId)
|
||||
}
|
||||
|
||||
// set context value
|
||||
if m.o.contextIdentifier != "" {
|
||||
ctx := context.WithValue(r.Context(), m.o.contextIdentifier, reqId)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
func (m *Middleware) generateRandomId() string {
|
||||
b := make([]byte, m.o.generateLength)
|
||||
for i := range b {
|
||||
b[i] = m.o.generateCharset[m.seededRand.Intn(len(m.o.generateCharset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
118
internal/app/api/core/middleware/tracing/middleware_test.go
Normal file
118
internal/app/api/core/middleware/tracing/middleware_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const defaultLength = 8
|
||||
const upstreamHeaderValue = "upstream-id"
|
||||
|
||||
func TestMiddleware_Handler_WithUpstreamHeader(t *testing.T) {
|
||||
m := New(WithUpstreamHeader("X-Upstream-Id"))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Header.Get("X-Upstream-Id")
|
||||
if reqId != upstreamHeaderValue {
|
||||
t.Errorf("expected upstream request id to be 'upstream-id', got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Upstream-Id", upstreamHeaderValue)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-Id") != upstreamHeaderValue {
|
||||
t.Errorf("expected X-Request-Id header to be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_GenerateNewId(t *testing.T) {
|
||||
idLen := 18
|
||||
m := New(WithIdLength(idLen))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if len(reqId) != 18 {
|
||||
t.Errorf("expected generated request id length to be %d, got %d", idLen, len(reqId))
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-Id") == "" || len(rr.Header().Get("X-Request-Id")) != idLen {
|
||||
t.Errorf("expected X-Request-Id header to be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_SetContextValue(t *testing.T) {
|
||||
m := New()
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("RequestId").(string)
|
||||
if reqId == "" || len(reqId) != defaultLength {
|
||||
t.Errorf("expected context request id to be set, got empty string")
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_SetCustomContextValue(t *testing.T) {
|
||||
m := New(WithContextIdentifier("Custom-Id"))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("Custom-Id").(string)
|
||||
if reqId == "" || len(reqId) != defaultLength {
|
||||
t.Errorf("expected context request id to be set, got empty string")
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdGenerated(t *testing.T) {
|
||||
m := New(WithIdLength(0))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if reqId != "" {
|
||||
t.Errorf("expected no request id to be generated, got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdHeaderSet(t *testing.T) {
|
||||
m := New(WithHeaderIdentifier(""))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if reqId != "" {
|
||||
t.Errorf("expected no request id to be generated, got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdContextSet(t *testing.T) {
|
||||
m := New(WithHeaderIdentifier(""))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("Request-Id")
|
||||
if reqId != nil {
|
||||
t.Errorf("expected no context request id to be set, got %v", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
85
internal/app/api/core/middleware/tracing/options.go
Normal file
85
internal/app/api/core/middleware/tracing/options.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package tracing
|
||||
|
||||
import "time"
|
||||
|
||||
// options is a struct that contains options for the tracing middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
upstreamReqIdHeader string
|
||||
headerIdentifier string
|
||||
contextIdentifier string
|
||||
generateLength int
|
||||
generateCharset string
|
||||
generateSeed int64
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the tracing middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithIdSeed sets the seed for the random request id.
|
||||
// If no seed is provided, the current timestamp is used.
|
||||
func WithIdSeed(seed int64) Option {
|
||||
return func(o *options) {
|
||||
o.generateSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdCharset sets the charset that is used to generate a random request id.
|
||||
// By default, upper-case letters and numbers are used.
|
||||
func WithIdCharset(charset string) Option {
|
||||
return func(o *options) {
|
||||
o.generateCharset = charset
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdLength specifies the length of generated random ids.
|
||||
// By default, a length of 8 is used. If the length is 0, no request id will be generated.
|
||||
func WithIdLength(len int) Option {
|
||||
return func(o *options) {
|
||||
o.generateLength = len
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaderIdentifier specifies the header name for the request id that is added to the response headers.
|
||||
// If the identifier is empty, the request id will not be added to the response headers.
|
||||
func WithHeaderIdentifier(identifier string) Option {
|
||||
return func(o *options) {
|
||||
o.headerIdentifier = identifier
|
||||
}
|
||||
}
|
||||
|
||||
// WithUpstreamHeader sets the upstream header name, that should be used to fetch the request id.
|
||||
// If no upstream header is found, a random id will be generated if the id-length parameter is set to a value > 0.
|
||||
func WithUpstreamHeader(header string) Option {
|
||||
return func(o *options) {
|
||||
o.upstreamReqIdHeader = header
|
||||
}
|
||||
}
|
||||
|
||||
// WithContextIdentifier specifies the value-key for the request id that is added to the request context.
|
||||
// If the identifier is empty, the request id will not be added to the context.
|
||||
// If the request id is added to the context, it can be retrieved with:
|
||||
// `id := r.Context().Value(THE-IDENTIFIER).(string)`
|
||||
func WithContextIdentifier(identifier string) Option {
|
||||
return func(o *options) {
|
||||
o.contextIdentifier = identifier
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
headerIdentifier: "X-Request-Id",
|
||||
contextIdentifier: "RequestId",
|
||||
generateSeed: time.Now().UnixNano(),
|
||||
generateCharset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
|
||||
generateLength: 8,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
75
internal/app/api/core/middleware/tracing/options_test.go
Normal file
75
internal/app/api/core/middleware/tracing/options_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithIdSeed(t *testing.T) {
|
||||
o := newOptions(WithIdSeed(12345))
|
||||
if o.generateSeed != 12345 {
|
||||
t.Errorf("expected generateSeed to be 12345, got %d", o.generateSeed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithIdCharset(t *testing.T) {
|
||||
o := newOptions(WithIdCharset("abc123"))
|
||||
if o.generateCharset != "abc123" {
|
||||
t.Errorf("expected generateCharset to be 'abc123', got %s", o.generateCharset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithIdLength(t *testing.T) {
|
||||
o := newOptions(WithIdLength(16))
|
||||
if o.generateLength != 16 {
|
||||
t.Errorf("expected generateLength to be 16, got %d", o.generateLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithHeaderIdentifier(t *testing.T) {
|
||||
o := newOptions(WithHeaderIdentifier("X-Custom-Id"))
|
||||
if o.headerIdentifier != "X-Custom-Id" {
|
||||
t.Errorf("expected headerIdentifier to be 'X-Custom-Id', got %s", o.headerIdentifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUpstreamHeader(t *testing.T) {
|
||||
o := newOptions(WithUpstreamHeader("X-Upstream-Id"))
|
||||
if o.upstreamReqIdHeader != "X-Upstream-Id" {
|
||||
t.Errorf("expected upstreamReqIdHeader to be 'X-Upstream-Id', got %s", o.upstreamReqIdHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithContextIdentifier(t *testing.T) {
|
||||
o := newOptions(WithContextIdentifier("Request-Id"))
|
||||
if o.contextIdentifier != "Request-Id" {
|
||||
t.Errorf("expected contextIdentifier to be 'Request-Id', got %s", o.contextIdentifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.generateLength != 8 {
|
||||
t.Errorf("expected generateLength to be 8, got %d", o.generateLength)
|
||||
}
|
||||
|
||||
if o.generateCharset != "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" {
|
||||
t.Errorf("expected generateCharset to be 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', got %s", o.generateCharset)
|
||||
}
|
||||
|
||||
if o.generateSeed == 0 {
|
||||
t.Errorf("expected generateSeed to be non-zero")
|
||||
}
|
||||
|
||||
if o.headerIdentifier != "X-Request-Id" {
|
||||
t.Errorf("expected headerIdentifier to be 'X-Request-Id', got %s", o.headerIdentifier)
|
||||
}
|
||||
|
||||
if o.upstreamReqIdHeader != "" {
|
||||
t.Errorf("expected upstreamReqIdHeader to be empty, got %s", o.upstreamReqIdHeader)
|
||||
}
|
||||
|
||||
if o.contextIdentifier != "RequestId" {
|
||||
t.Errorf("expected contextIdentifier to be 'RequestId', got %s", o.contextIdentifier)
|
||||
}
|
||||
}
|
259
internal/app/api/core/request/basic.go
Normal file
259
internal/app/api/core/request/basic.go
Normal file
@@ -0,0 +1,259 @@
|
||||
// Package request provides functions to extract parameters from the request.
|
||||
package request
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const CheckPrivateProxy = "PRIVATE"
|
||||
|
||||
// PathRaw returns the value of the named path parameter.
|
||||
func PathRaw(r *http.Request, name string) string {
|
||||
return r.PathValue(name)
|
||||
}
|
||||
|
||||
// Path returns the value of the named path parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Path(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(PathRaw(r, name))
|
||||
}
|
||||
|
||||
// PathDefault returns the value of the named path parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func PathDefault(r *http.Request, name string, defaultValue string) string {
|
||||
value := r.PathValue(name)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Path(r, name)
|
||||
}
|
||||
|
||||
// QueryRaw returns the value of the named query parameter.
|
||||
func QueryRaw(r *http.Request, name string) string {
|
||||
return r.URL.Query().Get(name)
|
||||
}
|
||||
|
||||
// Query returns the value of the named query parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Query(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(QueryRaw(r, name))
|
||||
}
|
||||
|
||||
// QueryDefault returns the value of the named query parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func QueryDefault(r *http.Request, name string, defaultValue string) string {
|
||||
if !r.URL.Query().Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Query(r, name)
|
||||
}
|
||||
|
||||
// QuerySlice returns the value of the named query parameter.
|
||||
// All slice values are trimmed of leading and trailing whitespace.
|
||||
func QuerySlice(r *http.Request, name string) []string {
|
||||
values, ok := r.URL.Query()[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, len(values))
|
||||
for i, value := range values {
|
||||
result[i] = strings.TrimSpace(value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// QuerySliceDefault returns the value of the named query parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// All slice values are trimmed of leading and trailing whitespace.
|
||||
func QuerySliceDefault(r *http.Request, name string, defaultValue []string) []string {
|
||||
if !r.URL.Query().Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return QuerySlice(r, name)
|
||||
}
|
||||
|
||||
// FragmentRaw returns the value of the named fragment parameter.
|
||||
func FragmentRaw(r *http.Request) string {
|
||||
return r.URL.Fragment
|
||||
}
|
||||
|
||||
// Fragment returns the value of the named fragment parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Fragment(r *http.Request) string {
|
||||
return strings.TrimSpace(FragmentRaw(r))
|
||||
}
|
||||
|
||||
// FragmentDefault returns the value of the named fragment parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func FragmentDefault(r *http.Request, defaultValue string) string {
|
||||
if r.URL.Fragment == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Fragment(r)
|
||||
}
|
||||
|
||||
// FormRaw returns the value of the named form parameter.
|
||||
func FormRaw(r *http.Request, name string) string {
|
||||
return r.FormValue(name)
|
||||
}
|
||||
|
||||
// Form returns the value of the named form parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Form(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(FormRaw(r, name))
|
||||
}
|
||||
|
||||
// DefaultForm returns the value of the named form parameter.
|
||||
// If the parameter is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func DefaultForm(r *http.Request, name, defaultValue string) string {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if !r.Form.Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Form(r, name)
|
||||
}
|
||||
|
||||
// HeaderRaw returns the value of the named header.
|
||||
func HeaderRaw(r *http.Request, name string) string {
|
||||
return r.Header.Get(name)
|
||||
}
|
||||
|
||||
// Header returns the value of the named header.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Header(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(HeaderRaw(r, name))
|
||||
}
|
||||
|
||||
// HeaderDefault returns the value of the named header.
|
||||
// If the header is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func HeaderDefault(r *http.Request, name, defaultValue string) string {
|
||||
if _, ok := textproto.MIMEHeader(r.Header)[name]; !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Header(r, name)
|
||||
}
|
||||
|
||||
// Cookie returns the value of the named cookie.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Cookie(r *http.Request, name string) string {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cookie.Value)
|
||||
}
|
||||
|
||||
// CookieDefault returns the value of the named cookie.
|
||||
// If the cookie is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func CookieDefault(r *http.Request, name, defaultValue string) string {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cookie.Value)
|
||||
}
|
||||
|
||||
// ClientIp returns the client IP address.
|
||||
//
|
||||
// As the request may come from a proxy, the function checks the
|
||||
// X-Real-Ip and X-Forwarded-For headers to get the real client IP
|
||||
// if the request IP matches one of the allowed proxy IPs.
|
||||
// If the special proxy value CheckPrivateProxy ("PRIVATE") is passed, the function will
|
||||
// also check the header if the request IP is a private IP address.
|
||||
func ClientIp(r *http.Request, allowedProxyIp ...string) string {
|
||||
ipStr, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
switch {
|
||||
case err != nil && strings.Contains(err.Error(), "missing port in address"):
|
||||
ipStr = strings.TrimSpace(r.RemoteAddr)
|
||||
case err != nil:
|
||||
ipStr = ""
|
||||
}
|
||||
IP := net.ParseIP(ipStr)
|
||||
if IP == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
isProxiedRequest := false
|
||||
if len(allowedProxyIp) > 0 {
|
||||
if slices.Contains(allowedProxyIp, IP.String()) {
|
||||
isProxiedRequest = true
|
||||
}
|
||||
if IP.IsPrivate() && slices.Contains(allowedProxyIp, CheckPrivateProxy) {
|
||||
isProxiedRequest = true
|
||||
}
|
||||
}
|
||||
|
||||
if isProxiedRequest {
|
||||
realClientIP := r.Header.Get("X-Real-Ip")
|
||||
if realClientIP == "" {
|
||||
realClientIP = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
if realClientIP != "" {
|
||||
realIpStr, _, err := net.SplitHostPort(strings.TrimSpace(realClientIP))
|
||||
switch {
|
||||
case err != nil && strings.Contains(err.Error(), "missing port in address"):
|
||||
realIpStr = realClientIP
|
||||
case err != nil:
|
||||
realIpStr = ipStr
|
||||
}
|
||||
realIP := net.ParseIP(realIpStr)
|
||||
if realIP == nil {
|
||||
return IP.String()
|
||||
}
|
||||
return realIP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return IP.String()
|
||||
}
|
||||
|
||||
// BodyJson decodes the JSON value from the request body into the target.
|
||||
// The target must be a pointer to a struct or slice.
|
||||
// The function returns an error if the JSON value could not be decoded.
|
||||
// The body reader is closed after reading.
|
||||
func BodyJson(r *http.Request, target any) error {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
return json.NewDecoder(r.Body).Decode(target)
|
||||
}
|
||||
|
||||
// BodyString returns the request body as a string.
|
||||
// The content is read and returned as is, without any processing.
|
||||
// The body is assumed to be UTF-8 encoded.
|
||||
func BodyString(r *http.Request) (string, error) {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bodyBytes), nil
|
||||
}
|
221
internal/app/api/core/request/basic_test.go
Normal file
221
internal/app/api/core/request/basic_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPath(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/test/sample"}}
|
||||
r.SetPathValue("first", "test")
|
||||
if got := Path(r, "first"); got != "test" {
|
||||
t.Errorf("Path() = %v, want %v", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPath(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/"}}
|
||||
if got := PathDefault(r, "test", "default"); got != "default" {
|
||||
t.Errorf("PathDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPath_noDefault(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/"}}
|
||||
r.SetPathValue("first", "test")
|
||||
if got := PathDefault(r, "first", "test"); got != "test" {
|
||||
t.Errorf("PathDefault() = %v, want %v", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value"}}
|
||||
if got := Query(r, "name"); got != "value" {
|
||||
t.Errorf("Query() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuery(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: ""}}
|
||||
if got := QueryDefault(r, "name", "default"); got != "default" {
|
||||
t.Errorf("QueryDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySlice(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value1 &name=value2"}}
|
||||
expected := []string{"value1", "value2"}
|
||||
if got := QuerySlice(r, "name"); !slices.Equal(got, expected) {
|
||||
t.Errorf("QuerySlice() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySlice_empty(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value1&name=value2"}}
|
||||
if got := QuerySlice(r, "nix"); !slices.Equal(got, nil) {
|
||||
t.Errorf("QuerySlice() = %v, want %v", got, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuerySlice(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: ""}}
|
||||
defaultValue := []string{"default1", "default2"}
|
||||
if got := QuerySliceDefault(r, "name", defaultValue); !slices.Equal(got, defaultValue) {
|
||||
t.Errorf("QuerySliceDefault() = %v, want %v", got, defaultValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragment(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Fragment: "section"}}
|
||||
if got := Fragment(r); got != "section" {
|
||||
t.Errorf("Fragment() = %v, want %v", got, "section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultFragment(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Fragment: ""}}
|
||||
if got := FragmentDefault(r, "default"); got != "default" {
|
||||
t.Errorf("FragmentDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForm(t *testing.T) {
|
||||
r := &http.Request{Form: url.Values{"name": {"value"}}}
|
||||
if got := Form(r, "name"); got != "value" {
|
||||
t.Errorf("Form() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultForm(t *testing.T) {
|
||||
r := &http.Request{Form: url.Values{}}
|
||||
if got := DefaultForm(r, "name", "default"); got != "default" {
|
||||
t.Errorf("DefaultForm() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeader(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{"X-Test-Header": {"value"}}}
|
||||
if got := Header(r, "X-Test-Header"); got != "value" {
|
||||
t.Errorf("Header() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultHeader(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{}}
|
||||
if got := HeaderDefault(r, "X-Test-Header", "default"); got != "default" {
|
||||
t.Errorf("HeaderDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookie(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{"Cookie": {"name=value"}}}
|
||||
if got := Cookie(r, "name"); got != "value" {
|
||||
t.Errorf("Cookie() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCookie(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{}}
|
||||
if got := CookieDefault(r, "name", "default"); got != "default" {
|
||||
t.Errorf("CookieDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345"}
|
||||
if got := ClientIp(r); got != "192.168.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_invalid(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "was_isn_des"}
|
||||
if got := ClientIp(r); got != "" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_ignore_header(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r); got != "192.168.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header1(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header2(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header3(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header4(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "8.8.8.8:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "8.8.8.8" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "8.8.8.8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header_invalid(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"so-sicher-nit"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "1.1.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "1.1.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyJson(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
jsonStr := `{"name": "test", "value": 123}`
|
||||
r := &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader(jsonStr)),
|
||||
}
|
||||
|
||||
var result TestStruct
|
||||
err := BodyJson(r, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("BodyJson() error = %v", err)
|
||||
}
|
||||
|
||||
expected := TestStruct{Name: "test", Value: 123}
|
||||
if result != expected {
|
||||
t.Errorf("BodyJson() = %v, want %v", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyString(t *testing.T) {
|
||||
bodyStr := "test body content"
|
||||
r := &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader(bodyStr)),
|
||||
}
|
||||
|
||||
result, err := BodyString(r)
|
||||
if err != nil {
|
||||
t.Fatalf("BodyString() error = %v", err)
|
||||
}
|
||||
|
||||
if result != bodyStr {
|
||||
t.Errorf("BodyString() = %v, want %v", result, bodyStr)
|
||||
}
|
||||
}
|
100
internal/app/api/core/respond/basic.go
Normal file
100
internal/app/api/core/respond/basic.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Package respond provides a set of utility functions to help with the HTTP response handling.
|
||||
package respond
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Status writes a response with the given status code.
|
||||
// The response will not contain any data.
|
||||
func Status(w http.ResponseWriter, code int) {
|
||||
w.WriteHeader(code)
|
||||
}
|
||||
|
||||
// String writes a plain text response with the given status code and data.
|
||||
// The Content-Type header is set to text/plain with a charset of utf-8.
|
||||
func String(w http.ResponseWriter, code int, data string) {
|
||||
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = w.Write([]byte(data))
|
||||
}
|
||||
|
||||
// JSON writes a JSON response with the given status code and data.
|
||||
// If data is nil, the response will null. The status code is set to the given code.
|
||||
// The Content-Type header is set to application/json.
|
||||
// If the given data is not JSON serializable, the response will not contain any data.
|
||||
// All encoding errors are silently ignored.
|
||||
func JSON(w http.ResponseWriter, code int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// if no data was given, simply return null
|
||||
if data == nil {
|
||||
w.WriteHeader(code)
|
||||
_, _ = w.Write([]byte("null"))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
_ = json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// Data writes a response with the given status code, content type, and data.
|
||||
// If no content type is provided, it is detected from the data.
|
||||
func Data(w http.ResponseWriter, code int, contentType string, data []byte) {
|
||||
if contentType == "" {
|
||||
contentType = http.DetectContentType(data) // ensure content type is set
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// Reader writes a response with the given status code, content type, and data.
|
||||
// The content length is optional, it is only set if the given length is greater than 0.
|
||||
func Reader(w http.ResponseWriter, code int, contentType string, contentLength int, data io.Reader) {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if contentLength > 0 {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(contentLength))
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = io.Copy(w, data)
|
||||
}
|
||||
|
||||
// Attachment writes a response with the given status code, content type, filename, and data.
|
||||
// If no content type is provided, it is detected from the data.
|
||||
func Attachment(w http.ResponseWriter, code int, filename, contentType string, data []byte) {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
|
||||
|
||||
Data(w, code, contentType, data)
|
||||
}
|
||||
|
||||
// AttachmentReader writes a response with the given status code, content type, filename, content length, and data.
|
||||
// The content length is optional, it is only set if the given length is greater than 0.
|
||||
func AttachmentReader(
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
filename, contentType string,
|
||||
contentLength int,
|
||||
data io.Reader,
|
||||
) {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
|
||||
|
||||
Reader(w, code, contentType, contentLength, data)
|
||||
}
|
||||
|
||||
// Redirect writes a response with the given status code and redirects to the given URL.
|
||||
// The redirect url will always be an absolute URL. If the given URL is relative,
|
||||
// the original request URL is used as the base.
|
||||
func Redirect(w http.ResponseWriter, r *http.Request, code int, url string) {
|
||||
http.Redirect(w, r, url, code)
|
||||
}
|
273
internal/app/api/core/respond/basic_test.go
Normal file
273
internal/app/api/core/respond/basic_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
Status(rec, http.StatusNoContent)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNoContent, res.StatusCode)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if len(body) != 0 {
|
||||
t.Errorf("expected no body, got %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
String(rec, http.StatusOK, "Hello, World!")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain;charset=utf-8" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain;charset=utf-8", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := map[string]string{"hello": "world"}
|
||||
JSON(rec, http.StatusOK, data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("expected content type %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
_ = json.NewDecoder(res.Body).Decode(&body)
|
||||
if body["hello"] != "world" {
|
||||
t.Errorf("expected body %v, got %v", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON_empty(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
JSON(rec, http.StatusOK, nil)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("expected content type %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "null" {
|
||||
t.Errorf("expected body %s, got %s", "null", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
Data(rec, http.StatusOK, "text/plain", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestData_noContentType(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte{0x1, 0x2, 0x3, 0x4, 0x5}
|
||||
Data(rec, http.StatusOK, "", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/octet-stream" {
|
||||
t.Errorf("expected content type %s, got %s", "application/octet-stream", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
reader := bytes.NewBufferString(string(data))
|
||||
Reader(rec, http.StatusOK, "text/plain", len(data), reader)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
if contentLength := res.Header.Get("Content-Length"); contentLength != strconv.Itoa(len(data)) {
|
||||
t.Errorf("expected content length %d, got %s", len(data), contentLength)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_unknownLength(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := bytes.NewBufferString("Hello, World!")
|
||||
Reader(rec, http.StatusOK, "text/plain", 0, data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
if contentLength := res.Header.Get("Content-Length"); contentLength != "" {
|
||||
t.Errorf("expected no content length, got %s", contentLength)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachment(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
Attachment(rec, http.StatusOK, "example.txt", "text/plain", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
|
||||
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachmentReader(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := bytes.NewBufferString("Hello, World!")
|
||||
AttachmentReader(rec, http.StatusOK, "example.txt", "text/plain", data.Len(), data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
|
||||
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/old", nil)
|
||||
url := "http://example.com/new"
|
||||
|
||||
Redirect(rec, req, http.StatusMovedPermanently, url)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusMovedPermanently {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
|
||||
}
|
||||
|
||||
if location := res.Header.Get("Location"); location != url {
|
||||
t.Errorf("expected location %s, got %s", url, location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect_relative(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/old/dir", nil)
|
||||
url := "newlocation/sub"
|
||||
want := "/old/newlocation/sub"
|
||||
|
||||
Redirect(rec, req, http.StatusMovedPermanently, url)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusMovedPermanently {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
|
||||
}
|
||||
|
||||
if location := res.Header.Get("Location"); location != want {
|
||||
t.Errorf("expected location %s, got %s", want, location)
|
||||
}
|
||||
}
|
46
internal/app/api/core/respond/template.go
Normal file
46
internal/app/api/core/respond/template.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TplData is a map of template data. This is a convenience type for passing data to templates.
|
||||
type TplData map[string]any
|
||||
|
||||
// TemplateInstance is an interface that wraps the ExecuteTemplate method.
|
||||
// It is implemented by the html/template and text/template packages.
|
||||
type TemplateInstance interface {
|
||||
// ExecuteTemplate executes a template with the given name and data.
|
||||
ExecuteTemplate(wr io.Writer, name string, data any) error
|
||||
}
|
||||
|
||||
// TemplateRenderer is a renderer that uses a template instance to render HTML or Text templates.
|
||||
type TemplateRenderer struct {
|
||||
t TemplateInstance
|
||||
}
|
||||
|
||||
// NewTemplateRenderer creates a new HTML or Text template renderer with the given template instance.
|
||||
func NewTemplateRenderer(t TemplateInstance) *TemplateRenderer {
|
||||
return &TemplateRenderer{t: t}
|
||||
}
|
||||
|
||||
// Render renders a template with the given name and data.
|
||||
// If rendering fails, it will panic with an error.
|
||||
func (r *TemplateRenderer) Render(w http.ResponseWriter, code int, name, contentType string, data any) {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.WriteHeader(code)
|
||||
|
||||
err := r.t.ExecuteTemplate(w, name, data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("error rendering template %s: %v", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
// HTML renders a template with the given name and data. It is a convenience method for Render.
|
||||
// The content type is set to "text/html" and the encoding to "utf-8".
|
||||
// If rendering fails, it will panic with an error.
|
||||
func (r *TemplateRenderer) HTML(w http.ResponseWriter, code int, name string, data any) {
|
||||
r.Render(w, code, name, "text/html;charset=utf-8", data)
|
||||
}
|
67
internal/app/api/core/respond/template_test.go
Normal file
67
internal/app/api/core/respond/template_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockTemplate struct {
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
func (m *mockTemplate) ExecuteTemplate(wr io.Writer, name string, data any) error {
|
||||
return m.tmpl.ExecuteTemplate(wr, name, data)
|
||||
}
|
||||
|
||||
func TestTemplateRenderer_Render(t *testing.T) {
|
||||
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}Hello, {{.}}!{{end}}`))
|
||||
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
renderer.Render(rec, http.StatusOK, "test", "text/plain", "World")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
expectedBody := "Hello, World!"
|
||||
if string(body) != expectedBody {
|
||||
t.Errorf("expected body %s, got %s", expectedBody, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateRenderer_HTML(t *testing.T) {
|
||||
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}<p>Hello, {{.}}!</p>{{end}}`))
|
||||
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
renderer.HTML(rec, http.StatusOK, "test", "World")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/html;charset=utf-8" {
|
||||
t.Errorf("expected content type %s, got %s", "text/html;charset=utf-8", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
expectedBody := "<p>Hello, World!</p>"
|
||||
if string(body) != expectedBody {
|
||||
t.Errorf("expected body %s, got %s", expectedBody, string(body))
|
||||
}
|
||||
}
|
@@ -2,27 +2,25 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/logging"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/recovery"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/tracing"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
random = rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
|
||||
)
|
||||
|
||||
const (
|
||||
RequestIDKey = "X-Request-ID"
|
||||
)
|
||||
@@ -30,19 +28,21 @@ const (
|
||||
type ApiVersion string
|
||||
type HandlerName string
|
||||
|
||||
type GroupSetupFn func(group *gin.RouterGroup)
|
||||
type GroupSetupFn func(group *routegroup.Bundle)
|
||||
|
||||
type ApiEndpointSetupFunc func() (ApiVersion, GroupSetupFn)
|
||||
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
server *gin.Engine
|
||||
versions map[ApiVersion]*gin.RouterGroup
|
||||
server *routegroup.Bundle
|
||||
tpl *respond.TemplateRenderer
|
||||
versions map[ApiVersion]*routegroup.Bundle
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, error) {
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
cfg: cfg,
|
||||
server: routegroup.New(http.NewServeMux()),
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
@@ -51,69 +51,39 @@ func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server,
|
||||
}
|
||||
hostname += ", version " + internal.Version
|
||||
|
||||
// Setup http server
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
s.server = gin.New()
|
||||
|
||||
s.server.Use(recovery.New().Handler)
|
||||
if cfg.Web.RequestLogging {
|
||||
if cfg.Advanced.LogLevel == "trace" {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
s.server.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
s.server.Use(logging.New(logging.WithLevel(logging.LogLevelDebug)).Handler)
|
||||
|
||||
c.Next()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
errorMsg := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
|
||||
slog.Debug("HTTP Request",
|
||||
"status", status,
|
||||
"latency", latency,
|
||||
"client", clientIP,
|
||||
"method", method,
|
||||
"path", path,
|
||||
"error", errorMsg,
|
||||
)
|
||||
}
|
||||
s.server.Use(cors.New().Handler)
|
||||
s.server.Use(tracing.New(
|
||||
tracing.WithContextIdentifier(RequestIDKey),
|
||||
tracing.WithHeaderIdentifier(RequestIDKey),
|
||||
).Handler)
|
||||
if cfg.Web.ExposeHostInfo {
|
||||
s.server.Use(func(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Served-By", hostname)
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
s.server.Use(gin.Recovery()).Use(func(c *gin.Context) {
|
||||
c.Writer.Header().Set("X-Served-By", hostname)
|
||||
c.Next()
|
||||
}).Use(func(c *gin.Context) {
|
||||
xRequestID := uuid(16)
|
||||
|
||||
c.Request.Header.Set(RequestIDKey, xRequestID)
|
||||
c.Set(RequestIDKey, xRequestID)
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// Setup templates
|
||||
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(apiTemplates, "assets/tpl/*.gohtml"))
|
||||
s.server.SetHTMLTemplate(templates)
|
||||
s.tpl = respond.NewTemplateRenderer(
|
||||
template.Must(template.New("").ParseFS(apiTemplates, "assets/tpl/*.gohtml")),
|
||||
)
|
||||
|
||||
// Serve static files
|
||||
imgFs := http.FS(fsMust(fs.Sub(apiStatics, "assets/img")))
|
||||
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
|
||||
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
|
||||
s.server.StaticFS("/img", imgFs)
|
||||
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
|
||||
s.server.StaticFS("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
|
||||
s.server.HandleFiles("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
|
||||
s.server.HandleFiles("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
|
||||
s.server.HandleFiles("/img", imgFs)
|
||||
s.server.HandleFiles("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
|
||||
s.server.HandleFiles("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
|
||||
|
||||
// Setup routes
|
||||
s.server.UseRawPath = true
|
||||
s.server.UnescapePathValues = true
|
||||
s.setupRoutes(endpoints...)
|
||||
s.setupFrontendRoutes()
|
||||
|
||||
@@ -136,9 +106,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Info("web service exited",
|
||||
"address", listenAddress,
|
||||
"error", err)
|
||||
slog.Info("web service exited", "address", listenAddress, "error", err)
|
||||
cancelFn()
|
||||
}
|
||||
}()
|
||||
@@ -157,18 +125,18 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
|
||||
s.server.GET("/api", s.landingPage)
|
||||
s.versions = make(map[ApiVersion]*gin.RouterGroup)
|
||||
s.server.HandleFunc("GET /api", s.landingPage)
|
||||
s.versions = make(map[ApiVersion]*routegroup.Bundle)
|
||||
|
||||
for _, setupFunc := range endpoints {
|
||||
version, groupSetupFn := setupFunc()
|
||||
|
||||
if _, ok := s.versions[version]; !ok {
|
||||
s.versions[version] = s.server.Group(fmt.Sprintf("/api/%s", version))
|
||||
s.versions[version] = s.server.Mount(fmt.Sprintf("/api/%s", version))
|
||||
|
||||
// OpenAPI documentation (via RapiDoc)
|
||||
s.versions[version].GET("/swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
|
||||
s.versions[version].GET("/doc.html", s.rapiDocHandler(version))
|
||||
s.versions[version].HandleFunc("GET /swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
|
||||
s.versions[version].HandleFunc("GET /doc.html", s.rapiDocHandler(version))
|
||||
|
||||
groupSetupFn(s.versions[version])
|
||||
}
|
||||
@@ -177,25 +145,27 @@ func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
|
||||
|
||||
func (s *Server) setupFrontendRoutes() {
|
||||
// Serve static files
|
||||
s.server.GET("/", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/app")
|
||||
s.server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.Redirect(w, r, http.StatusMovedPermanently, "/app")
|
||||
})
|
||||
s.server.GET("/favicon.ico", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/app/favicon.ico")
|
||||
|
||||
s.server.HandleFunc("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.Redirect(w, r, http.StatusMovedPermanently, "/app/favicon.ico")
|
||||
})
|
||||
s.server.StaticFS("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
|
||||
|
||||
s.server.HandleFiles("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
|
||||
}
|
||||
|
||||
func (s *Server) landingPage(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "index.gohtml", gin.H{
|
||||
func (s *Server) landingPage(w http.ResponseWriter, _ *http.Request) {
|
||||
s.tpl.HTML(w, http.StatusOK, "index.gohtml", respond.TplData{
|
||||
"Version": internal.Version,
|
||||
"Year": time.Now().Year(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) rapiDocHandler(version ApiVersion) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "rapidoc.gohtml", gin.H{
|
||||
func (s *Server) rapiDocHandler(version ApiVersion) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
s.tpl.HTML(w, http.StatusOK, "rapidoc.gohtml", respond.TplData{
|
||||
"RapiDocSource": "/js/rapidoc-min.js",
|
||||
"ApiSpecUrl": fmt.Sprintf("/doc/%s_swagger.yaml", version),
|
||||
"Version": internal.Version,
|
||||
@@ -210,9 +180,3 @@ func fsMust(f fs.FS, err error) fs.FS {
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func uuid(len int) string {
|
||||
bytes := make([]byte, len)
|
||||
random.Read(bytes)
|
||||
return base64.StdEncoding.EncodeToString(bytes)[:len]
|
||||
}
|
||||
|
Reference in New Issue
Block a user