chore: replace gin with standard lib net/http

This commit is contained in:
Christoph Haas 2025-03-09 21:16:42 +01:00
parent 7473132932
commit 0206952182
58 changed files with 5302 additions and 1390 deletions

View File

@ -7,6 +7,7 @@ import (
"syscall"
"time"
"github.com/go-playground/validator/v10"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal"
@ -101,21 +102,48 @@ func main() {
err = backend.Startup(ctx)
internal.AssertNoError(err)
apiFrontend := handlersV0.NewRestApi(cfg, backend)
validatorManager := validator.New()
// region API v0 (SPA frontend)
apiV0Session := handlersV0.NewSessionWrapper(cfg)
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(backend, apiV0Auth, apiV0Session, validatorManager)
apiV0EndpointUsers := handlersV0.NewUserEndpoint(backend, apiV0Auth, validatorManager)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(backend, apiV0Auth, validatorManager)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(backend, apiV0Auth, validatorManager)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
apiFrontend := handlersV0.NewRestApi(apiV0Session,
apiV0EndpointAuth,
apiV0EndpointUsers,
apiV0EndpointInterfaces,
apiV0EndpointPeers,
apiV0EndpointConfig,
apiV0EndpointTest,
)
// endregion API v0 (SPA frontend)
// region API v1 (User REST API)
apiV1Auth := handlersV1.NewAuthenticationHandler(userManager)
apiV1BackendUsers := backendV1.NewUserService(cfg, userManager)
apiV1BackendPeers := backendV1.NewPeerService(cfg, wireGuardManager, userManager)
apiV1BackendInterfaces := backendV1.NewInterfaceService(cfg, wireGuardManager)
apiV1BackendProvisioning := backendV1.NewProvisioningService(cfg, userManager, wireGuardManager, cfgFileManager)
apiV1BackendMetrics := backendV1.NewMetricsService(cfg, database, userManager, wireGuardManager)
apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1BackendUsers)
apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1BackendPeers)
apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1BackendInterfaces)
apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1BackendProvisioning)
apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1BackendMetrics)
apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1Auth, validatorManager, apiV1BackendUsers)
apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1Auth, validatorManager, apiV1BackendPeers)
apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1Auth, validatorManager, apiV1BackendInterfaces)
apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1Auth, validatorManager,
apiV1BackendProvisioning)
apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1Auth, validatorManager, apiV1BackendMetrics)
apiV1 := handlersV1.NewRestApi(
userManager,
apiV1EndpointUsers,
apiV1EndpointPeers,
apiV1EndpointInterfaces,
@ -123,6 +151,8 @@ func main() {
apiV1EndpointMetrics,
)
// endregion API v1 (User REST API)
webSrv, err := core.NewServer(cfg, apiFrontend, apiV1)
internal.AssertNoError(err)

View File

@ -4,6 +4,7 @@ import LoginView from '../views/LoginView.vue'
import InterfaceView from '../views/InterfaceView.vue'
import {authStore} from '@/stores/auth'
import {securityStore} from '@/stores/security'
import {notify} from "@kyvg/vue3-notification";
const router = createRouter({
@ -63,6 +64,7 @@ const router = createRouter({
router.beforeEach(async (to) => {
const auth = authStore()
const sec = securityStore()
// check if the request was a successful oauth login
if ('wgLoginState' in to.query && !auth.IsAuthenticated) {
@ -112,6 +114,10 @@ router.beforeEach(async (to) => {
auth.SetReturnUrl(to.fullPath) // store original destination before starting the auth process
return '/login'
}
if (publicPages.includes(to.path)) {
await sec.LoadSecurityProperties() // make sure we have a valid csrf token
}
})
export default router

44
go.mod
View File

@ -4,26 +4,25 @@ go 1.24.0
require (
github.com/a8m/envsubst v1.4.3
github.com/alexedwards/scs/v2 v2.8.0
github.com/coreos/go-oidc/v3 v3.12.0
github.com/gin-contrib/cors v1.7.3
github.com/gin-contrib/sessions v1.0.2
github.com/gin-gonic/gin v1.10.0
github.com/glebarez/sqlite v1.11.0
github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-pkgz/routegroup v1.3.1
github.com/go-playground/validator/v10 v10.25.0
github.com/google/uuid v1.6.0
github.com/prometheus-community/pro-bing v0.6.1
github.com/prometheus/client_golang v1.21.0
github.com/prometheus/client_golang v1.21.1
github.com/stretchr/testify v1.10.0
github.com/swaggo/swag v1.16.4
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca
github.com/vardius/message-bus v1.1.5
github.com/vishvananda/netlink v1.3.0
github.com/xhit/go-simple-mail/v2 v2.16.0
github.com/yeqown/go-qrcode/v2 v2.2.5
github.com/yeqown/go-qrcode/writer/compressed v1.0.1
golang.org/x/crypto v0.35.0
golang.org/x/oauth2 v0.27.0
golang.org/x/sys v0.30.0
golang.org/x/crypto v0.36.0
golang.org/x/oauth2 v0.28.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@ -37,15 +36,10 @@ require (
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bytedance/sonic v1.12.9 // indirect
github.com/bytedance/sonic/loader v0.2.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dchest/uniuri v1.2.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.0.0 // indirect
github.com/glebarez/go-sqlite v1.22.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
@ -55,16 +49,11 @@ require (
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.25.0 // indirect
github.com/go-sql-driver/mysql v1.9.0 // indirect
github.com/go-test/deep v1.1.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/gorilla/context v1.1.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/sessions v1.4.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.2 // indirect
@ -73,9 +62,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@ -83,28 +70,21 @@ require (
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/microsoft/go-mssqldb v1.8.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/yeqown/reedsolomon v1.0.0 // indirect
golang.org/x/arch v0.14.0 // indirect
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/tools v0.30.0 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/tools v0.31.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
google.golang.org/protobuf v1.36.5 // indirect
modernc.org/libc v1.61.13 // indirect

130
go.sum
View File

@ -29,52 +29,27 @@ github.com/a8m/envsubst v1.4.3 h1:kDF7paGK8QACWYaQo6KtyYBozY2jhQrTuNNuUxQkhJY=
github.com/a8m/envsubst v1.4.3/go.mod h1:4jjHWQlZoaXPoLQUb7H2qT4iLkZDdmEQiOUogdUmqVU=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw=
github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20181103040241-659414f458e1/go.mod h1:dkChI7Tbtx7H1Tj7TqGSZMOeGpMP5gLHtjroHd4agiI=
github.com/bytedance/sonic v1.12.9 h1:Od1BvK55NnewtGaJsTDeAOSnLVO2BTSLOe0+ooKokmQ=
github.com/bytedance/sonic v1.12.9/go.mod h1:uVvFidNmlt9+wa31S1urfwwthTWteBgG0hWuoKAXTx8=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.3 h1:yctD0Q3v2NOGfSWPLPvG2ggA2kV6TS6s4wioyEqssH0=
github.com/bytedance/sonic/loader v0.2.3/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4=
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
github.com/gin-contrib/cors v1.7.3 h1:hV+a5xp8hwJoTw7OY+a70FsL8JkVVFTXw9EcfrYUdns=
github.com/gin-contrib/cors v1.7.3/go.mod h1:M3bcKZhxzsvI+rlRSkkxHyljJt1ESd93COUvemZ79j4=
github.com/gin-contrib/sessions v0.0.0-20190101140330-dc5246754963/go.mod h1:4lkInX8nHSR62NSmhXM3xtPeMSyfiR58NaEz+om1lHM=
github.com/gin-contrib/sessions v1.0.2 h1:UaIjUvTH1cMeOdj3in6dl+Xb6It8RiKRF9Z1anbUyCA=
github.com/gin-contrib/sessions v1.0.2/go.mod h1:KxKxWqWP5LJVDCInulOl4WbLzK2KSPlLesfZ66wRvMs=
github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0=
github.com/gin-gonic/gin v1.3.0/go.mod h1:7cKuhb5qV2ggCFctp2fJQ+ErvciLZrIeoOSOm6mUr7Y=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-asn1-ber/asn1-ber v1.5.7 h1:DTX+lbVTWaTw1hQ+PbZPlnDZPEIs0SS/GCZAl535dDk=
github.com/go-asn1-ber/asn1-ber v1.5.7/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
@ -89,6 +64,8 @@ github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9Z
github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-pkgz/routegroup v1.3.1 h1:XAVWskX8Iup6HoQD9zv+gJx4DOJC2DSkKBHCMeeW8/s=
github.com/go-pkgz/routegroup v1.3.1/go.mod h1:kDDPDRLRiRY1vnENrZJw1jQAzQX7fvsbsHGRQFNQfKc=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@ -102,8 +79,6 @@ github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
@ -112,32 +87,18 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
@ -169,21 +130,10 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kidstuff/mongostore v0.0.0-20181113001930-e650cd85ee4b/go.mod h1:g2nVr8KZVXJSS97Jo8pJ0jgq29P6H7dG0oplUA86MQw=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
@ -192,7 +142,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@ -201,26 +150,17 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc=
github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw=
github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
@ -228,17 +168,14 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus-community/pro-bing v0.6.1 h1:EQukUOma9YFZRPe4DGSscxUf9LH07rpqwisNWjSZrgU=
github.com/prometheus-community/pro-bing v0.6.1/go.mod h1:jNCOI3D7pmTCeaoF41cNS6uaxeFY/Gmc3ffwbuJVzAQ=
github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quasoft/memstore v0.0.0-20180925164028-84a050167438/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@ -246,8 +183,6 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -262,13 +197,6 @@ github.com/swaggo/swag v1.16.4/go.mod h1:VBsHJRsDvfYvqoiMKnsdwhNV9LEMHgEDZcyVYX0
github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns=
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 h1:q0hKh5a5FRkhuTb5JNfgjzpzvYLHjH0QOgPZPYnRWGA=
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca h1:lpvAjPK+PcxnbcB8H7axIb4fMNwjX9bE4DzwPjGg8aE=
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca/go.mod h1:XXKxNbpoLihvvT7orUZbs/iZayg1n4ip7iJakJPAwA8=
github.com/vardius/message-bus v1.1.5 h1:YSAC2WB4HRlwc4neFPTmT88kzzoiQ+9WRRbej/E/LZc=
github.com/vardius/message-bus v1.1.5/go.mod h1:6xladCV2lMkUAE4bzzS85qKOiB5miV7aBVRafiTJGqw=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
@ -285,8 +213,6 @@ github.com/yeqown/go-qrcode/writer/compressed v1.0.1/go.mod h1:BJScsGUIKM+eg0CCL
github.com/yeqown/reedsolomon v1.0.0 h1:x1h/Ej/uJnNu8jaX7GLHBWmZKCAWjEJTetkqaabr4B0=
github.com/yeqown/reedsolomon v1.0.0/go.mod h1:P76zpcn2TCuL0ul1Fso373qHRc69LKwAw/Iy6g1WiiM=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4=
golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
@ -300,18 +226,17 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@ -329,11 +254,10 @@ golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -341,9 +265,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20181228144115-9a3f9b0469bb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -364,8 +287,8 @@ golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -393,16 +316,16 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU=
golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
@ -411,12 +334,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y=
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
@ -458,6 +377,5 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=

View File

@ -0,0 +1,214 @@
package cors
import (
"net/http"
"slices"
"strconv"
"strings"
)
// Middleware is a type that creates a new CORS middleware. The CORS middleware
// adds Cross-Origin Resource Sharing headers to the response. This middleware should
// be used to allow cross-origin requests to your server.
type Middleware struct {
o options
varyHeaders string // precomputed Vary header
allOrigins bool // all origins are allowed
}
// New returns a new CORS middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
// set vary headers
if m.o.allowPrivateNetworks {
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"
} else {
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers"
}
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
m.allOrigins = true
}
return m
}
// Handler returns the CORS middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle preflight requests and stop the chain as some other
// middleware may not handle OPTIONS requests correctly.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
m.handlePreflight(w, r)
w.WriteHeader(http.StatusNoContent) // always return 204 No Content
return
}
// handle normal CORS requests
m.handleNormal(w, r)
next.ServeHTTP(w, r) // execute the next handler
})
}
// region internal-helpers
// handlePreflight handles preflight requests. If the request was successful, this function will
// write the CORS headers and return. If the request was not successful, this function will
// not add any CORS headers and return - thus the CORS request is considered invalid.
func (m *Middleware) handlePreflight(w http.ResponseWriter, r *http.Request) {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
w.Header().Add("Vary", m.varyHeaders)
// check origin
origin := r.Header.Get("Origin")
if origin == "" {
return // not a valid CORS request
}
if !m.originAllowed(origin) {
return
}
// check method
reqMethod := r.Header.Get("Access-Control-Request-Method")
if !m.methodAllowed(reqMethod) {
return
}
// check headers
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
if !m.headersAllowed(reqHeaders) {
return
}
// set CORS headers for the successful preflight request
if m.allOrigins {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
}
w.Header().Set("Access-Control-Allow-Methods", reqMethod)
if reqHeaders != "" {
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
}
if m.o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if m.o.allowPrivateNetworks && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
w.Header().Set("Access-Control-Allow-Private-Network", "true")
}
if m.o.maxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.o.maxAge))
}
}
// handleNormal handles normal CORS requests. If the request was successful, this function will
// write the CORS headers to the response. If the request was not successful, this function will
// not add any CORS headers to the response. In this case, the CORS request is considered invalid.
func (m *Middleware) handleNormal(w http.ResponseWriter, r *http.Request) {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
w.Header().Add("Vary", "Origin")
// check origin
origin := r.Header.Get("Origin")
if origin == "" {
return // not a valid CORS request
}
if !m.originAllowed(origin) {
return
}
// check method
if !m.methodAllowed(r.Method) {
return
}
// set CORS headers for the successful CORS request
if m.allOrigins {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
}
if len(m.o.exposedHeaders) > 0 {
w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.o.exposedHeaders, ", "))
}
if m.o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
func (m *Middleware) originAllowed(origin string) bool {
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
return true // everything is allowed
}
// check simple origins
if slices.Contains(m.o.allowedOrigins, origin) {
return true
}
// check wildcard origins
for _, allowedOrigin := range m.o.allowedOriginPatterns {
if allowedOrigin.match(origin) {
return true
}
}
return false
}
func (m *Middleware) methodAllowed(method string) bool {
if method == http.MethodOptions {
return true // preflight request is always allowed
}
if len(m.o.allowedMethods) == 1 && m.o.allowedMethods[0] == "*" {
return true // everything is allowed
}
if slices.Contains(m.o.allowedMethods, method) {
return true
}
return false
}
func (m *Middleware) headersAllowed(headers string) bool {
if headers == "" {
return true // no headers are requested
}
if len(m.o.allowedHeaders) == 0 {
return false // no headers are allowed
}
if _, ok := m.o.allowedHeaders["*"]; ok {
return true // everything is allowed
}
// split headers by comma (according to definition, the headers are sorted and in lowercase)
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers
for header := range strings.SplitSeq(headers, ",") {
if _, ok := m.o.allowedHeaders[strings.TrimSpace(header)]; !ok {
return false
}
}
return true
}
// endregion internal-helpers

View File

@ -0,0 +1,101 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestMiddleware_New(t *testing.T) {
m := New(WithAllowedOrigins("*"))
if len(m.varyHeaders) == 0 {
t.Errorf("expected vary headers to be populated, got %v", m.varyHeaders)
}
if !m.allOrigins {
t.Errorf("expected allOrigins to be true, got %v", m.allOrigins)
}
}
func TestMiddleware_Handler_normal(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Result().StatusCode != http.StatusOK {
t.Errorf("expected status code 200, got %d", w.Result().StatusCode)
}
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestMiddleware_Handler_preflight(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "http://example.com", nil)
req.Header.Set("Origin", "http://example.com")
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Result().StatusCode != http.StatusNoContent {
t.Errorf("expected status code 204, got %d", w.Result().StatusCode)
}
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestMiddleware_originAllowed(t *testing.T) {
m := New(WithAllowedOrigins("http://example.com"))
if !m.originAllowed("http://example.com") {
t.Errorf("expected origin 'http://example.com' to be allowed")
}
if m.originAllowed("http://notallowed.com") {
t.Errorf("expected origin 'http://notallowed.com' to be not allowed")
}
}
func TestMiddleware_methodAllowed(t *testing.T) {
m := New(WithAllowedMethods(http.MethodGet, http.MethodPost))
if !m.methodAllowed(http.MethodGet) {
t.Errorf("expected method 'GET' to be allowed")
}
if m.methodAllowed(http.MethodDelete) {
t.Errorf("expected method 'DELETE' to be not allowed")
}
}
func TestMiddleware_headersAllowed(t *testing.T) {
m := New(WithAllowedHeaders("Content-Type", "Authorization"))
if !m.headersAllowed("content-type, authorization") {
t.Errorf("expected headers 'Content-Type, Authorization' to be allowed")
}
if m.headersAllowed("x-custom-header") {
t.Errorf("expected header 'X-Custom-Header' to be not allowed")
}
}

View File

@ -0,0 +1,133 @@
package cors
import (
"net/http"
"strings"
)
type void struct{}
// options is a struct that contains options for the CORS middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
allowedOrigins []string // origins without wildcards
allowedOriginPatterns []wildcard // origins with wildcards
allowedMethods []string
allowedHeaders map[string]void
exposedHeaders []string // these are in addition to the CORS-safelisted response headers
allowCredentials bool
allowPrivateNetworks bool
maxAge int
}
// Option is a type that is used to set options for the CORS middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithAllowedOrigins sets the allowed origins for the CORS middleware.
// If the special "*" value is present in the list, all origins will be allowed.
// An origin may contain a wildcard (*) to replace 0 or more characters
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
// Only one wildcard can be used per origin.
// By default, all origins are allowed (*).
func WithAllowedOrigins(origins ...string) Option {
return func(o *options) {
o.allowedOrigins = nil
o.allowedOriginPatterns = nil
for _, origin := range origins {
if len(origin) > 1 && strings.Contains(origin, "*") {
o.allowedOriginPatterns = append(
o.allowedOriginPatterns,
newWildcard(origin),
)
} else {
o.allowedOrigins = append(o.allowedOrigins, origin)
}
}
}
}
// WithAllowedMethods sets the allowed methods for the CORS middleware.
// By default, all methods are allowed (*).
func WithAllowedMethods(methods ...string) Option {
return func(o *options) {
o.allowedMethods = methods
}
}
// WithAllowedHeaders sets the allowed headers for the CORS middleware.
// By default, all headers are allowed (*).
func WithAllowedHeaders(headers ...string) Option {
return func(o *options) {
o.allowedHeaders = make(map[string]void)
for _, header := range headers {
// allowed headers are always checked in lowercase
o.allowedHeaders[strings.ToLower(header)] = void{}
}
}
}
// WithExposedHeaders sets the exposed headers for the CORS middleware.
// By default, no headers are exposed.
func WithExposedHeaders(headers ...string) Option {
return func(o *options) {
o.exposedHeaders = nil
for _, header := range headers {
o.exposedHeaders = append(o.exposedHeaders, http.CanonicalHeaderKey(header))
}
}
}
// WithAllowCredentials sets the allow credentials option for the CORS middleware.
// This setting indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
// By default, credentials are not allowed.
func WithAllowCredentials(allow bool) Option {
return func(o *options) {
o.allowCredentials = allow
}
}
// WithAllowPrivateNetworks sets the allow private networks option for the CORS middleware.
// This setting indicates whether to accept cross-origin requests over a private network.
func WithAllowPrivateNetworks(allow bool) Option {
return func(o *options) {
o.allowPrivateNetworks = allow
}
}
// WithMaxAge sets the max age (in seconds) for the CORS middleware.
// The maximum age indicates how long (in seconds) the results of a preflight request
// can be cached. A value of 0 means that no Access-Control-Max-Age header is sent back,
// resulting in browsers using their default value (5s by spec).
// If you need to force a 0 max-age, set it to a negative value (ie: -1).
// By default, the max age is 7200 seconds.
func WithMaxAge(age int) Option {
return func(o *options) {
o.maxAge = age
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
allowedOrigins: []string{"*"},
allowedMethods: []string{
http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete,
},
allowedHeaders: map[string]void{"*": {}},
exposedHeaders: nil,
allowCredentials: false,
allowPrivateNetworks: false,
maxAge: 0,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@ -0,0 +1,96 @@
package cors
import (
"maps"
"net/http"
"slices"
"testing"
)
func TestWithAllowedOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
wantNormal []string
wantWildcard []wildcard
}{
{
name: "No origins",
origins: []string{},
wantNormal: nil,
wantWildcard: nil,
},
{
name: "Single origin",
origins: []string{"http://example.com"},
wantNormal: []string{"http://example.com"},
wantWildcard: nil,
},
{
name: "Wildcard origin",
origins: []string{"http://*.example.com"},
wantNormal: nil,
wantWildcard: []wildcard{newWildcard("http://*.example.com")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := newOptions(WithAllowedOrigins(tt.origins...))
if !slices.Equal(o.allowedOrigins, tt.wantNormal) {
t.Errorf("got %v, want %v", o, tt.wantNormal)
}
if !slices.Equal(o.allowedOriginPatterns, tt.wantWildcard) {
t.Errorf("got %v, want %v", o, tt.wantWildcard)
}
})
}
}
func TestWithAllowedMethods(t *testing.T) {
methods := []string{http.MethodGet, http.MethodPost}
o := newOptions(WithAllowedMethods(methods...))
if !slices.Equal(o.allowedMethods, methods) {
t.Errorf("got %v, want %v", o.allowedMethods, methods)
}
}
func TestWithAllowedHeaders(t *testing.T) {
headers := []string{"Content-Type", "Authorization"}
o := newOptions(WithAllowedHeaders(headers...))
expectedHeaders := map[string]void{"content-type": {}, "authorization": {}}
if !maps.Equal(o.allowedHeaders, expectedHeaders) {
t.Errorf("got %v, want %v", o.allowedHeaders, expectedHeaders)
}
}
func TestWithExposedHeaders(t *testing.T) {
headers := []string{"X-Custom-Header"}
o := newOptions(WithExposedHeaders(headers...))
expectedHeaders := []string{http.CanonicalHeaderKey("X-Custom-Header")}
if !slices.Equal(o.exposedHeaders, expectedHeaders) {
t.Errorf("got %v, want %v", o.exposedHeaders, expectedHeaders)
}
}
func TestWithAllowCredentials(t *testing.T) {
o := newOptions(WithAllowCredentials(true))
if !o.allowCredentials {
t.Errorf("got %v, want %v", o.allowCredentials, true)
}
}
func TestWithAllowPrivateNetworks(t *testing.T) {
o := newOptions(WithAllowPrivateNetworks(true))
if !o.allowPrivateNetworks {
t.Errorf("got %v, want %v", o.allowPrivateNetworks, true)
}
}
func TestWithMaxAge(t *testing.T) {
maxAge := 3600
o := newOptions(WithMaxAge(maxAge))
if o.maxAge != maxAge {
t.Errorf("got %v, want %v", o.maxAge, maxAge)
}
}

View File

@ -0,0 +1,33 @@
package cors
import "strings"
// wildcard is a type that represents a wildcard string.
// This type allows faster matching of strings with a wildcard
// in comparison to using regex.
type wildcard struct {
prefix string
suffix string
}
// match returns true if the string s has the prefix and suffix of the wildcard.
func (w wildcard) match(s string) bool {
return len(s) >= len(w.prefix)+len(w.suffix) &&
strings.HasPrefix(s, w.prefix) &&
strings.HasSuffix(s, w.suffix)
}
func newWildcard(s string) wildcard {
if i := strings.IndexByte(s, '*'); i >= 0 {
return wildcard{
prefix: s[:i],
suffix: s[i+1:],
}
}
// fallback, usually this case should not happen
return wildcard{
prefix: s,
suffix: "",
}
}

View File

@ -0,0 +1,94 @@
package cors
import "testing"
func TestWildcardMatch(t *testing.T) {
tests := []struct {
name string
wildcard wildcard
input string
expected bool
}{
{
name: "Match with prefix and suffix",
wildcard: newWildcard("http://*.example.com"),
input: "http://sub.example.com",
expected: true,
},
{
name: "No match with different prefix",
wildcard: newWildcard("http://*.example.com"),
input: "https://sub.example.com",
expected: false,
},
{
name: "No match with different suffix",
wildcard: newWildcard("http://*.example.com"),
input: "http://sub.example.org",
expected: false,
},
{
name: "Match with empty suffix",
wildcard: newWildcard("http://*"),
input: "http://example.com",
expected: true,
},
{
name: "Match with empty prefix",
wildcard: newWildcard("*.example.com"),
input: "sub.example.com",
expected: true,
},
{
name: "No match with empty prefix and different suffix",
wildcard: newWildcard("*.example.com"),
input: "sub.example.org",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.wildcard.match(tt.input); got != tt.expected {
t.Errorf("wildcard.match(%s) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}
func TestNewWildcard(t *testing.T) {
tests := []struct {
name string
input string
expected wildcard
}{
{
name: "Wildcard with prefix and suffix",
input: "http://*.example.com",
expected: wildcard{prefix: "http://", suffix: ".example.com"},
},
{
name: "Wildcard with empty suffix",
input: "http://*",
expected: wildcard{prefix: "http://", suffix: ""},
},
{
name: "Wildcard with empty prefix",
input: "*.example.com",
expected: wildcard{prefix: "", suffix: ".example.com"},
},
{
name: "No wildcard character",
input: "http://example.com",
expected: wildcard{prefix: "http://example.com", suffix: ""},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := newWildcard(tt.input); got != tt.expected {
t.Errorf("newWildcard(%s) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}

View File

@ -0,0 +1,137 @@
package csrf
import (
"context"
"net/http"
"slices"
)
// ContextValueIdentifier is the context value identifier for the CSRF token.
// The token is only stored in the context if the RefreshToken function was called before.
const ContextValueIdentifier = "_csrf_token"
// Middleware is a type that creates a new CSRF middleware. The CSRF middleware
// can be used to mitigate Cross-Site Request Forgery attacks.
type Middleware struct {
o options
}
// New returns a new CSRF middleware with the provided options.
func New(sessionReader SessionReader, sessionWriter SessionWriter, opts ...Option) *Middleware {
opts = append(opts, withSessionReader(sessionReader), withSessionWriter(sessionWriter))
o := newOptions(opts...)
m := &Middleware{
o: o,
}
checkForPRNG()
return m
}
// Handler returns the CSRF middleware handler. This middleware validates the CSRF token and calls the specified
// error handler if an invalid CSRF token was found.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if slices.Contains(m.o.ignoreMethods, r.Method) {
next.ServeHTTP(w, r) // skip CSRF check for ignored methods
return
}
// get the token from the request
token := m.o.tokenGetter(r)
storedToken := m.o.sessionGetter(r)
if !tokenEqual(token, storedToken) {
m.o.errCallback(w, r)
return
}
next.ServeHTTP(w, r) // execute the next handler
})
}
// RefreshToken generates a new CSRF Token and stores it in the session. The token is also passed to subsequent handlers
// via the context value ContextValueIdentifier.
func (m *Middleware) RefreshToken(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if GetToken(r.Context()) != "" {
// token already generated higher up in the chain
next.ServeHTTP(w, r)
return
}
// generate a new token
token := generateToken(m.o.tokenLength)
key := generateToken(m.o.tokenLength)
// mask the token
maskedToken := maskToken(token, key)
// store the encoded token in the session
encodedToken := encodeToken(maskedToken)
m.o.sessionWriter(r, encodedToken)
// pass the token down the chain via the context
r = r.WithContext(setToken(r.Context(), encodedToken))
next.ServeHTTP(w, r)
})
}
// region token-access
// GetToken retrieves the CSRF token from the given context. Ensure that the RefreshToken function was called before,
// otherwise, no token is populated in the context.
func GetToken(ctx context.Context) string {
token, ok := ctx.Value(ContextValueIdentifier).(string)
if !ok {
return ""
}
return token
}
// endregion token-access
// region internal-helpers
func setToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, ContextValueIdentifier, token)
}
// defaultTokenGetter is the default token getter function for the CSRF middleware.
// It checks the request form values, URL query parameters, and headers for the CSRF token.
// The order of precedence is:
// 1. Header "X-CSRF-TOKEN"
// 2. Header "X-XSRF-TOKEN"
// 3. URL query parameter "_csrf"
// 4. Form value "_csrf"
func defaultTokenGetter(r *http.Request) string {
if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 {
return t
}
if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 {
return t
}
if t := r.URL.Query().Get("_csrf"); len(t) > 0 {
return t
}
if t := r.FormValue("_csrf"); len(t) > 0 {
return t
}
return ""
}
// defaultErrorHandler is the default error handler function for the CSRF middleware.
// It writes a 403 Forbidden response.
func defaultErrorHandler(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
}
// endregion internal-helpers

View File

@ -0,0 +1,251 @@
package csrf
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/h44z/wg-portal/internal/app/api/core/request"
)
func TestMiddleware_Handler(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
method string
token string
wantStatus int
}{
{"ValidToken", "POST", "stored-token", http.StatusOK},
{"ValidToken2", "PUT", "stored-token", http.StatusOK},
{"ValidToken3", "GET", "stored-token", http.StatusOK},
{"InvalidToken", "POST", "invalid-token", http.StatusForbidden},
{"IgnoredMethod", "GET", "", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/", nil)
req.Header.Set("X-CSRF-TOKEN", tt.token)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != tt.wantStatus {
t.Errorf("Handler() status = %d, want %d", status, tt.wantStatus)
}
})
}
}
func TestMiddleware_RefreshToken(t *testing.T) {
sessionToken := ""
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := GetToken(r.Context())
if token == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
}
if sessionToken == "" {
t.Errorf("RefreshToken() did not set token in session")
}
}
func TestMiddleware_RefreshToken_chained(t *testing.T) {
sessionToken := ""
tokenWrites := 0
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
tokenWrites++
}
m := New(sessionReader, sessionWriter)
handler := m.RefreshToken(m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := GetToken(r.Context())
if token == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusOK)
})))
req := httptest.NewRequest("POST", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
}
if sessionToken == "" {
t.Errorf("RefreshToken() did not set token in session")
}
if tokenWrites != 1 {
t.Errorf("RefreshToken() wrote token to session more than once: %d", tokenWrites)
}
}
func TestMiddleware_RefreshToken_Handler(t *testing.T) {
sessionToken := ""
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
// simulate two requests: first one GET request with the RefreshToken handler, the next one is a PUT request with
// the token from the first request added as X-CSRF-TOKEN header
// first request
retrievedToken := ""
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
retrievedToken = GetToken(r.Context())
if retrievedToken == "" {
t.Errorf("RefreshToken() did not set token in context")
}
w.WriteHeader(http.StatusAccepted)
}))
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusAccepted {
t.Errorf("Handler() status = %d, want %d", status, http.StatusAccepted)
}
if retrievedToken == "" {
t.Errorf("no token retrieved")
}
if retrievedToken != sessionToken {
t.Errorf("token in context does not match token in session")
}
// second request
req = httptest.NewRequest("PUT", "/", nil)
req.Header.Set("X-CSRF-TOKEN", retrievedToken)
rr = httptest.NewRecorder()
handler = m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}
func TestMiddleware_Handler_FormBody(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyData, err := request.BodyString(r)
if err != nil {
t.Errorf("Handler() error = %v, want nil", err)
}
// ensure that the body is empty - ParseForm() should have been called before by the CSRF middleware
if bodyData != "" {
t.Errorf("Handler() bodyData = %s, want empty", bodyData)
}
if r.FormValue("_csrf") != "stored-token" {
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/", nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Form = make(map[string][]string)
req.Form.Add("_csrf", "stored-token")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}
func TestMiddleware_Handler_FormBodyAvailable(t *testing.T) {
sessionToken := "stored-token"
sessionReader := func(r *http.Request) string {
return sessionToken
}
sessionWriter := func(r *http.Request, token string) {
sessionToken = token
}
m := New(sessionReader, sessionWriter)
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyData, err := request.BodyString(r)
if err != nil {
t.Errorf("Handler() error = %v, want nil", err)
}
// ensure that the body is not empty, as the CSRF middleware should not have read the body
if bodyData != "the original body" {
t.Errorf("Handler() bodyData = %s, want %s", bodyData, "the original body")
}
// check if the token is available in the form values (from query parameters)
if r.FormValue("_csrf") != "stored-token" {
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/?_csrf=stored-token", nil)
req.Header.Set("Content-Type", "text/plain")
req.Body = io.NopCloser(strings.NewReader("the original body"))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
}
}

View File

@ -0,0 +1,88 @@
package csrf
import "net/http"
type SessionReader func(r *http.Request) string
type SessionWriter func(r *http.Request, token string)
// options is a struct that contains options for the CSRF middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
tokenLength int
ignoreMethods []string
errCallbackOverride bool
errCallback func(w http.ResponseWriter, r *http.Request)
tokenGetterOverride bool
tokenGetter func(r *http.Request) string
sessionGetter SessionReader
sessionWriter SessionWriter
}
// Option is a type that is used to set options for the CSRF middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithTokenLength is a method that sets the token length for the CSRF middleware.
// The default value is 32.
func WithTokenLength(length int) Option {
return func(o *options) {
o.tokenLength = length
}
}
// WithErrorCallback is a method that sets the error callback function for the CSRF middleware.
// The error callback function is called when the CSRF token is invalid.
// The default behavior is to write a 403 Forbidden response.
func WithErrorCallback(fn func(w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.errCallback = fn
o.errCallbackOverride = true
}
}
// WithTokenGetter is a method that sets the token getter function for the CSRF middleware.
// The token getter function is called to get the CSRF token from the request.
// The default behavior is to get the token from the "X-CSRF-Token" header.
func WithTokenGetter(fn func(r *http.Request) string) Option {
return func(o *options) {
o.tokenGetter = fn
o.tokenGetterOverride = true
}
}
// withSessionReader is a method that sets the session reader function for the CSRF middleware.
// The session reader function is called to get the CSRF token from the session.
func withSessionReader(fn SessionReader) Option {
return func(o *options) {
o.sessionGetter = fn
}
}
// withSessionWriter is a method that sets the session writer function for the CSRF middleware.
// The session writer function is called to write the CSRF token to the session.
func withSessionWriter(fn SessionWriter) Option {
return func(o *options) {
o.sessionWriter = fn
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
tokenLength: 32,
ignoreMethods: []string{"GET", "HEAD", "OPTIONS"},
errCallbackOverride: false,
errCallback: defaultErrorHandler,
tokenGetterOverride: false,
tokenGetter: defaultTokenGetter,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@ -0,0 +1,75 @@
package csrf
import (
"net/http"
"testing"
)
func TestWithTokenLength(t *testing.T) {
o := newOptions(WithTokenLength(64))
if o.tokenLength != 64 {
t.Errorf("WithTokenLength() = %d, want %d", o.tokenLength, 64)
}
}
func TestWithErrorCallback(t *testing.T) {
callback := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}
o := newOptions(WithErrorCallback(callback))
if !o.errCallbackOverride {
t.Errorf("WithErrorCallback() did not set errCallbackOverride to true")
}
if o.errCallback == nil {
t.Errorf("WithErrorCallback() did not set errCallback")
}
}
func TestWithTokenGetter(t *testing.T) {
getter := func(r *http.Request) string {
return "test-token"
}
o := newOptions(WithTokenGetter(getter))
if !o.tokenGetterOverride {
t.Errorf("WithTokenGetter() did not set tokenGetterOverride to true")
}
if o.tokenGetter == nil {
t.Errorf("WithTokenGetter() did not set tokenGetter")
}
}
func TestWithSessionReader(t *testing.T) {
reader := func(r *http.Request) string {
return "session-token"
}
o := newOptions(withSessionReader(reader))
if o.sessionGetter == nil {
t.Errorf("withSessionReader() did not set sessionGetter")
}
}
func TestWithSessionWriter(t *testing.T) {
writer := func(r *http.Request, token string) {
// do nothing
}
o := newOptions(withSessionWriter(writer))
if o.sessionWriter == nil {
t.Errorf("withSessionWriter() did not set sessionWriter")
}
}
func TestNewOptionsDefaults(t *testing.T) {
o := newOptions()
if o.tokenLength != 32 {
t.Errorf("newOptions() default tokenLength = %d, want %d", o.tokenLength, 32)
}
if len(o.ignoreMethods) != 3 {
t.Errorf("newOptions() default ignoreMethods length = %d, want %d", len(o.ignoreMethods), 3)
}
if o.errCallback == nil {
t.Errorf("newOptions() default errCallback is nil")
}
if o.tokenGetter == nil {
t.Errorf("newOptions() default tokenGetter is nil")
}
}

View File

@ -0,0 +1,90 @@
package csrf
import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"slices"
)
// checkForPRNG is a function that checks if a cryptographically secure PRNG is available.
// If it is not available, the function panics.
func checkForPRNG() {
buf := make([]byte, 1)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err))
}
}
// generateToken is a function that generates a secure random CSRF token.
func generateToken(length int) []byte {
bytes := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
panic(err)
}
return bytes
}
// encodeToken is a function that encodes a token to a base64 string.
func encodeToken(token []byte) string {
return base64.URLEncoding.EncodeToString(token)
}
// decodeToken is a function that decodes a base64 string to a token.
func decodeToken(token string) ([]byte, error) {
return base64.URLEncoding.DecodeString(token)
}
// maskToken is a function that masks a token with a given key.
// The returned byte slice contains the key + the masked token.
// The key needs to have the same length as the token, otherwise the function panics.
// So the resulting slice has a length of len(token) * 2.
func maskToken(token, key []byte) []byte {
if len(token) != len(key) {
panic("token and key must have the same length")
}
// masked contains the key in the first half and the XOR masked token in the second half
tokenLength := len(token)
masked := make([]byte, tokenLength*2)
for i := 0; i < len(token); i++ {
masked[i] = key[i]
masked[i+tokenLength] = token[i] ^ key[i] // XOR mask
}
return masked
}
// unmaskToken is a function that unmask a token which contains the key in the first half.
// The returned byte slice contains the unmasked token, it has exactly half the length of the input slice.
func unmaskToken(masked []byte) []byte {
tokenLength := len(masked) / 2
token := make([]byte, tokenLength)
for i := 0; i < tokenLength; i++ {
token[i] = masked[i] ^ masked[i+tokenLength] // XOR unmask
}
return token
}
// tokenEqual is a function that compares two tokens for equality.
func tokenEqual(a, b string) bool {
decodedA, err := decodeToken(a)
if err != nil {
return false
}
decodedB, err := decodeToken(b)
if err != nil {
return false
}
unmaskedA := unmaskToken(decodedA)
unmaskedB := unmaskToken(decodedB)
return slices.Equal(unmaskedA, unmaskedB)
}

View File

@ -0,0 +1,81 @@
package csrf
import (
"encoding/base64"
"testing"
)
func TestCheckForPRNG(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("checkForPRNG() panicked: %v", r)
}
}()
checkForPRNG()
}
func TestGenerateToken(t *testing.T) {
length := 32
token := generateToken(length)
if len(token) != length {
t.Errorf("generateToken() returned token of length %d, expected %d", len(token), length)
}
}
func TestEncodeToken(t *testing.T) {
token := []byte("testtoken")
encoded := encodeToken(token)
expected := base64.URLEncoding.EncodeToString(token)
if encoded != expected {
t.Errorf("encodeToken() = %v, want %v", encoded, expected)
}
}
func TestDecodeToken(t *testing.T) {
token := "dGVzdHRva2Vu"
expected := []byte("testtoken")
decoded, err := decodeToken(token)
if err != nil {
t.Errorf("decodeToken() error = %v", err)
}
if string(decoded) != string(expected) {
t.Errorf("decodeToken() = %v, want %v", decoded, expected)
}
}
func TestMaskToken(t *testing.T) {
token := []byte("testtoken")
key := []byte("keykeykey")
masked := maskToken(token, key)
if len(masked) != len(token)*2 {
t.Errorf("maskToken() returned masked token of length %d, expected %d", len(masked), len(token)*2)
}
}
func TestUnmaskToken(t *testing.T) {
token := []byte("testtoken")
key := []byte("keykeykey")
masked := maskToken(token, key)
unmasked := unmaskToken(masked)
if string(unmasked) != string(token) {
t.Errorf("unmaskToken() = %v, want %v", unmasked, token)
}
}
func TestTokenEqual(t *testing.T) {
tokenA := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}))
tokenB := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
if !tokenEqual(tokenA, tokenB) {
t.Errorf("tokenEqual() = false, want true")
}
tokenC := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x07, 0x08, 0x09}))
if !tokenEqual(tokenA, tokenC) {
t.Errorf("tokenEqual() = false, want true")
}
tokenD := encodeToken(maskToken([]byte{0x09, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
if tokenEqual(tokenA, tokenD) {
t.Errorf("tokenEqual() = true, want false")
}
}

View File

@ -0,0 +1,199 @@
package logging
import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"
)
// LogLevel is an enumeration of the different log levels.
type LogLevel int
const (
LogLevelDebug LogLevel = iota
LogLevelInfo
LogLevelWarn
LogLevelError
)
// Logger is an interface that defines the methods that a logger must implement.
// This allows the logging middleware to be used with different logging libraries.
type Logger interface {
// Debugf logs a message at debug level.
Debugf(format string, args ...any)
// Infof logs a message at info level.
Infof(format string, args ...any)
// Warnf logs a message at warn level.
Warnf(format string, args ...any)
// Errorf logs a message at error level.
Errorf(format string, args ...any)
}
// Middleware is a type that creates a new logging middleware. The logging middleware
// logs information about each request.
type Middleware struct {
o options
}
// New returns a new logging middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
return m
}
// Handler returns the logging middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := newWriterWrapper(w)
start := time.Now()
defer func() {
info := m.extractInfoMap(r, start, ww)
if m.o.logger == nil {
msg, args := m.buildSlogMessageAndArguments(info)
m.logMsg(msg, args...)
} else {
msg := m.buildNormalLogMessage(info)
m.logMsg(msg)
}
}()
next.ServeHTTP(ww, r)
})
}
func (m *Middleware) extractInfoMap(r *http.Request, start time.Time, ww *writerWrapper) map[string]any {
info := make(map[string]any)
info["method"] = r.Method
info["path"] = r.URL.Path
info["protocol"] = r.Proto
info["clientIP"] = r.Header.Get("X-Forwarded-For")
if info["clientIP"] == "" {
// If the X-Forwarded-For header is not set, use the remote address without the port number.
lastColonIndex := strings.LastIndex(r.RemoteAddr, ":")
switch lastColonIndex {
case -1:
info["clientIP"] = r.RemoteAddr
default:
info["clientIP"] = r.RemoteAddr[:lastColonIndex]
}
}
info["userAgent"] = r.UserAgent()
info["referer"] = r.Header.Get("Referer")
info["duration"] = time.Since(start).String()
info["status"] = ww.StatusCode
info["dataLength"] = ww.WrittenBytes
if m.o.headerRequestIdKey != "" {
info["headerRequestId"] = r.Header.Get(m.o.headerRequestIdKey)
}
if m.o.contextRequestIdKey != "" {
info["contextRequestId"], _ = r.Context().Value(m.o.contextRequestIdKey).(string)
}
return info
}
func (m *Middleware) buildNormalLogMessage(info map[string]any) string {
switch {
case info["headerRequestId"] != nil && info["contextRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s ctx=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["headerRequestId"], info["contextRequestId"])
case info["headerRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["headerRequestId"])
case info["contextRequestId"] != nil:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - ctx=%s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"],
info["contextRequestId"])
default:
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s",
info["method"], info["path"], info["protocol"],
info["status"], info["dataLength"],
info["duration"],
info["clientIP"], info["userAgent"], info["referer"])
}
}
func (m *Middleware) buildSlogMessageAndArguments(info map[string]any) (message string, args []any) {
message = fmt.Sprintf("%s %s", info["method"], info["path"])
// Use a fixed order for the keys, so that the message is always the same.
// Skip method and path as they are already in the message.
keys := []string{
"protocol",
"status",
"dataLength",
"duration",
"clientIP",
"userAgent",
"referer",
"headerRequestId",
"contextRequestId",
}
for _, k := range keys {
if v, ok := info[k]; ok {
args = append(args, k, v) // only add key, value if it exists
}
}
return
}
func (m *Middleware) addPrefix(message string) string {
if m.o.prefix != "" {
return m.o.prefix + " " + message
}
return message
}
func (m *Middleware) logMsg(message string, args ...any) {
message = m.addPrefix(message)
if m.o.logger != nil {
switch m.o.logLevel {
case LogLevelDebug:
m.o.logger.Debugf(message, args...)
case LogLevelInfo:
m.o.logger.Infof(message, args...)
case LogLevelWarn:
m.o.logger.Warnf(message, args...)
case LogLevelError:
m.o.logger.Errorf(message, args...)
default:
m.o.logger.Infof(message, args...)
}
} else {
switch m.o.logLevel {
case LogLevelDebug:
slog.Debug(message, args...)
case LogLevelInfo:
slog.Info(message, args...)
case LogLevelWarn:
slog.Warn(message, args...)
case LogLevelError:
slog.Error(message, args...)
default:
slog.Info(message, args...)
}
}
}

View File

@ -0,0 +1,148 @@
package logging
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type mockLogger struct {
messages []string
}
func (m *mockLogger) Debugf(format string, _ ...any) {
m.messages = append(m.messages, "DEBUG: "+format)
}
func (m *mockLogger) Infof(format string, _ ...any) {
m.messages = append(m.messages, "INFO: "+format)
}
func (m *mockLogger) Warnf(format string, _ ...any) {
m.messages = append(m.messages, "WARN: "+format)
}
func (m *mockLogger) Errorf(format string, _ ...any) {
m.messages = append(m.messages, "ERROR: "+format)
}
func TestMiddleware_Normal(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
}
expected := "Hello, World!"
if rr.Body.String() != expected {
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
}
if len(logger.messages) == 0 {
t.Errorf("expected log messages, got none")
}
if len(logger.messages) != 0 && !strings.Contains(logger.messages[0], "ERROR: GET /foo") {
t.Errorf("expected log message to contain request info, got %v", logger.messages[0])
}
}
func TestMiddleware_Extended(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithContextRequestIdKey("requestId"), WithHeaderRequestIdKey("X-Request-Id")).
Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
}
expected := "Hello, World!"
if rr.Body.String() != expected {
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
}
}
func TestMiddleware_Logger_remoteAddr(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "xhamster.com:1234"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrNoPort(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "xhamster.com"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrV6(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "[::1]:4711"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}
func TestMiddleware_Logger_remoteAddrV6NoPort(t *testing.T) {
logger := &mockLogger{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("Hello, World!"))
})
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.RemoteAddr = "[::1]"
rr := httptest.NewRecorder()
middleware.ServeHTTP(rr, req)
}

View File

@ -0,0 +1,80 @@
package logging
// options is a struct that contains options for the logging middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
logLevel LogLevel
logger Logger
prefix string
contextRequestIdKey string
headerRequestIdKey string
}
// Option is a type that is used to set options for the logging middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithLevel is a method that sets the log level for the logging middleware.
// Possible values are LogLevelDebug, LogLevelInfo, LogLevelWarn, and LogLevelError.
// The default value is LogLevelInfo.
func WithLevel(level LogLevel) Option {
return func(o *options) {
o.logLevel = level
}
}
// WithPrefix is a method that sets the prefix for the logging middleware.
// If a prefix is set, it will be prepended to each log message. A space will
// be added between the prefix and the log message.
// The default value is an empty string.
func WithPrefix(prefix string) Option {
return func(o *options) {
o.prefix = prefix
}
}
// WithContextRequestIdKey is a method that sets the key for the request ID in the
// request context. If a key is set, the logging middleware will use this key to
// retrieve the request ID from the request context.
// The default value is an empty string, meaning the request ID will not be logged.
func WithContextRequestIdKey(key string) Option {
return func(o *options) {
o.contextRequestIdKey = key
}
}
// WithHeaderRequestIdKey is a method that sets the key for the request ID in the
// request headers. If a key is set, the logging middleware will use this key to
// retrieve the request ID from the request headers.
// The default value is an empty string, meaning the request ID will not be logged.
func WithHeaderRequestIdKey(key string) Option {
return func(o *options) {
o.headerRequestIdKey = key
}
}
// WithLogger is a method that sets the logger for the logging middleware.
// If a logger is set, the logging middleware will use this logger to log messages.
// The default logger is the structured slog logger.
func WithLogger(logger Logger) Option {
return func(o *options) {
o.logger = logger
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
logLevel: LogLevelInfo,
logger: nil,
prefix: "",
contextRequestIdKey: "",
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@ -0,0 +1,88 @@
package logging
import (
"testing"
)
func TestWithLevel(t *testing.T) {
// table test to check all possible log levels
levels := []LogLevel{
LogLevelDebug,
LogLevelInfo,
LogLevelWarn,
LogLevelError,
}
for _, level := range levels {
opt := WithLevel(level)
o := newOptions(opt)
if o.logLevel != level {
t.Errorf("expected log level to be %v, got %v", level, o.logLevel)
}
}
}
func TestWithPrefix(t *testing.T) {
prefix := "TEST"
opt := WithPrefix(prefix)
o := newOptions(opt)
if o.prefix != prefix {
t.Errorf("expected prefix to be %v, got %v", prefix, o.prefix)
}
}
func TestWithContextRequestIdKey(t *testing.T) {
key := "contextKey"
opt := WithContextRequestIdKey(key)
o := newOptions(opt)
if o.contextRequestIdKey != key {
t.Errorf("expected contextRequestIdKey to be %v, got %v", key, o.contextRequestIdKey)
}
}
func TestWithHeaderRequestIdKey(t *testing.T) {
key := "headerKey"
opt := WithHeaderRequestIdKey(key)
o := newOptions(opt)
if o.headerRequestIdKey != key {
t.Errorf("expected headerRequestIdKey to be %v, got %v", key, o.headerRequestIdKey)
}
}
func TestWithLogger(t *testing.T) {
logger := &mockLogger{}
opt := WithLogger(logger)
o := newOptions(opt)
if o.logger != logger {
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
}
}
func TestDefaults(t *testing.T) {
o := newOptions()
if o.logLevel != LogLevelInfo {
t.Errorf("expected log level to be %v, got %v", LogLevelInfo, o.logLevel)
}
if o.logger != nil {
t.Errorf("expected logger to be nil, got %v", o.logger)
}
if o.prefix != "" {
t.Errorf("expected prefix to be empty, got %v", o.prefix)
}
if o.contextRequestIdKey != "" {
t.Errorf("expected contextRequestIdKey to be empty, got %v", o.contextRequestIdKey)
}
if o.headerRequestIdKey != "" {
t.Errorf("expected headerRequestIdKey to be empty, got %v", o.headerRequestIdKey)
}
}

View File

@ -0,0 +1,45 @@
package logging
import (
"net/http"
)
// writerWrapper wraps a http.ResponseWriter and tracks the number of bytes written to it.
// It also tracks the http response code passed to the WriteHeader func of
// the ResponseWriter.
type writerWrapper struct {
http.ResponseWriter
// StatusCode is the last http response code passed to the WriteHeader func of
// the ResponseWriter. If no such call is made, a default code of http.StatusOK
// is assumed instead.
StatusCode int
// WrittenBytes is the number of bytes successfully written by the Write or
// ReadFrom function of the ResponseWriter. ResponseWriters may also write
// data to their underlaying connection directly (e.g. headers), but those
// are not tracked. Therefor the number of Written bytes will usually match
// the size of the response body.
WrittenBytes int64
}
// WriteHeader wraps the WriteHeader method of the ResponseWriter and tracks the
// http response code passed to it.
func (w *writerWrapper) WriteHeader(code int) {
w.StatusCode = code
w.ResponseWriter.WriteHeader(code)
}
// Write wraps the Write method of the ResponseWriter and tracks the number of bytes
// written to it.
func (w *writerWrapper) Write(data []byte) (int, error) {
n, err := w.ResponseWriter.Write(data)
w.WrittenBytes += int64(n)
return n, err
}
// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter.
// It initializes the StatusCode to http.StatusOK.
func newWriterWrapper(w http.ResponseWriter) *writerWrapper {
return &writerWrapper{ResponseWriter: w, StatusCode: http.StatusOK}
}

View File

@ -0,0 +1,85 @@
package logging
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestWriterWrapper_WriteHeader(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
ww.WriteHeader(http.StatusNotFound)
if ww.StatusCode != http.StatusNotFound {
t.Errorf("expected status code to be %v, got %v", http.StatusNotFound, ww.StatusCode)
}
if rr.Code != http.StatusNotFound {
t.Errorf("expected recorder status code to be %v, got %v", http.StatusNotFound, rr.Code)
}
}
func TestWriterWrapper_Write(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
data := []byte("Hello, World!")
n, err := ww.Write(data)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if n != len(data) {
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
}
if ww.WrittenBytes != int64(len(data)) {
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
}
if rr.Body.String() != string(data) {
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
}
}
func TestWriterWrapper_WriteWithHeaders(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
data := []byte("Hello, World!")
n, err := ww.Write(data)
ww.Header().Set("Content-Type", "text/plain")
ww.Header().Set("X-Some-Header", "some-value")
ww.WriteHeader(http.StatusTeapot)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if n != len(data) {
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
}
if ww.WrittenBytes != int64(len(data)) {
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
}
if rr.Body.String() != string(data) {
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
}
if ww.StatusCode != http.StatusTeapot {
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, ww.StatusCode)
}
}
func TestNewWriterWrapper(t *testing.T) {
rr := httptest.NewRecorder()
ww := newWriterWrapper(rr)
if ww.StatusCode != http.StatusOK {
t.Errorf("expected initial status code to be %v, got %v", http.StatusOK, ww.StatusCode)
}
if ww.WrittenBytes != 0 {
t.Errorf("expected initial WrittenBytes to be %v, got %v", 0, ww.WrittenBytes)
}
if ww.ResponseWriter != rr {
t.Errorf("expected ResponseWriter to be %v, got %v", rr, ww.ResponseWriter)
}
}

View File

@ -0,0 +1,133 @@
package recovery
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"runtime/debug"
"strings"
)
// Logger is an interface that defines the methods that a logger must implement.
// This allows the logging middleware to be used with different logging libraries.
type Logger interface {
// Errorf logs a message at error level.
Errorf(format string, args ...any)
}
// Middleware is a type that creates a new recovery middleware. The recovery middleware
// recovers from panics and returns an Internal Server Error response. This middleware should
// be the first middleware in the middleware chain, so that it can recover from panics in other
// middlewares.
type Middleware struct {
o options
}
// New returns a new recovery middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
}
return m
}
// Handler returns the recovery middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
realErr, ok := err.(error)
if !ok {
realErr = fmt.Errorf("%v", err)
}
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
brokenPipe := isBrokenPipeError(realErr)
// Log the error and stack trace
if m.o.logCallback != nil {
m.o.logCallback(realErr, stack, brokenPipe)
}
switch {
case brokenPipe && m.o.brokenPipeCallback != nil:
m.o.brokenPipeCallback(realErr, stack, w, r)
case !brokenPipe && m.o.errCallback != nil:
m.o.errCallback(realErr, stack, w, r)
default:
// no callback set, simply recover and do nothing...
}
}
}()
next.ServeHTTP(w, r)
})
}
func addPrefix(o options, message string) string {
if o.defaultLogPrefix != "" {
return o.defaultLogPrefix + " " + message
}
return message
}
// defaultErrCallback is the default error callback function for the recovery middleware.
// It writes a JSON response with an Internal Server Error status code. If the exposeStackTrace option is
// enabled, the stack trace is included in the response.
func getDefaultErrCallback(o options) func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
return func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
responseBody := map[string]interface{}{
"error": "Internal Server Error",
}
if o.exposeStackTrace && len(stack) > 0 {
responseBody["stack"] = string(stack)
}
jsonBody, _ := json.Marshal(responseBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write(jsonBody)
}
}
// getDefaultLogCallback is the default log callback function for the recovery middleware.
// It logs the error and stack trace using the structured slog logger or the provided logger in Error level.
func getDefaultLogCallback(o options) func(error, []byte, bool) {
return func(err error, stack []byte, brokenPipe bool) {
if brokenPipe {
return // by default, ignore broken pipe errors
}
switch {
case o.useSlog:
slog.Error(addPrefix(o, err.Error()), "stack", string(stack))
case o.logger != nil:
o.logger.Errorf(fmt.Sprintf("%s; stacktrace=%s", addPrefix(o, err.Error()), string(stack)))
default:
// no logger set, do nothing...
}
}
}
func isBrokenPipeError(err error) bool {
var syscallErr *os.SyscallError
if errors.As(err, &syscallErr) {
errMsg := strings.ToLower(syscallErr.Err.Error())
if strings.Contains(errMsg, "broken pipe") ||
strings.Contains(errMsg, "connection reset by peer") {
return true
}
}
return false
}

View File

@ -0,0 +1,149 @@
package recovery
import (
"errors"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
type mockLogger struct{}
func (m *mockLogger) Errorf(_ string, _ ...any) {}
func TestMiddleware(t *testing.T) {
tests := []struct {
name string
options []Option
panicSimulator func()
expectedStatus int
expectedBody string
expectStack bool
}{
{
name: "default behavior",
options: []Option{},
panicSimulator: func() {
panic(errors.New("test panic"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: `{"error":"Internal Server Error"}`,
},
{
name: "custom error callback",
options: []Option{
WithErrCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
w.Write([]byte("custom error"))
}),
},
panicSimulator: func() {
panic(errors.New("test panic"))
},
expectedStatus: http.StatusTeapot,
expectedBody: "custom error",
},
{
name: "broken pipe error",
options: []Option{
WithBrokenPipeCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte("broken pipe"))
}),
},
panicSimulator: func() {
panic(&os.SyscallError{Err: errors.New("broken pipe")})
},
expectedStatus: http.StatusServiceUnavailable,
expectedBody: "broken pipe",
},
{
name: "default callback broken pipe error",
options: nil,
panicSimulator: func() {
panic(&os.SyscallError{Err: errors.New("broken pipe")})
},
expectedStatus: http.StatusOK,
expectedBody: "",
},
{
name: "default callback normal error",
options: nil,
panicSimulator: func() {
panic("something went wrong")
},
expectedStatus: http.StatusInternalServerError,
expectedBody: "{\"error\":\"Internal Server Error\"}",
},
{
name: "default callback with stack trace",
options: []Option{
WithExposeStackTrace(true),
},
panicSimulator: func() {
panic("something went wrong")
},
expectedStatus: http.StatusInternalServerError,
expectedBody: "\"stack\":",
expectStack: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := New(tt.options...).Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tt.panicSimulator()
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("expected status %v, got %v", tt.expectedStatus, rr.Code)
}
if !tt.expectStack && rr.Body.String() != tt.expectedBody {
t.Errorf("expected body %v, got %v", tt.expectedBody, rr.Body.String())
}
if tt.expectStack && !strings.Contains(rr.Body.String(), tt.expectedBody) {
t.Errorf("expected body to contain %v, got %v", tt.expectedBody, rr.Body.String())
}
})
}
}
func TestIsBrokenPipeError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "broken pipe error",
err: &os.SyscallError{Err: errors.New("broken pipe")},
expected: true,
},
{
name: "connection reset by peer error",
err: &os.SyscallError{Err: errors.New("connection reset by peer")},
expected: true,
},
{
name: "other error",
err: errors.New("other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isBrokenPipeError(tt.err)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@ -0,0 +1,129 @@
package recovery
import "net/http"
// options is a struct that contains options for the recovery middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
logger Logger
useSlog bool
errCallbackOverride bool
errCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
brokenPipeCallbackOverride bool
brokenPipeCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
exposeStackTrace bool
defaultLogPrefix string
logCallbackOverride bool
logCallback func(err error, stack []byte, brokenPipe bool)
}
// Option is a type that is used to set options for the recovery middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithErrCallback sets the error callback function for the recovery middleware.
// The error callback function is called when a panic is recovered by the middleware.
// This function completely overrides the default behavior of the middleware. It is the
// responsibility of the user to handle the error and write a response to the client.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithErrCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.errCallback = fn
o.errCallbackOverride = true
}
}
// WithBrokenPipeCallback sets the broken pipe callback function for the recovery middleware.
// The broken pipe callback function is called when a broken pipe error is recovered by the middleware.
// This function completely overrides the default behavior of the middleware. It is the responsibility
// of the user to handle the error and write a response to the client.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithBrokenPipeCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
return func(o *options) {
o.brokenPipeCallback = fn
o.brokenPipeCallbackOverride = true
}
}
// WithLogCallback sets the log callback function for the recovery middleware.
// The log callback function is called when a panic is recovered by the middleware.
// This function allows the user to log the error and stack trace. The default behavior is to log
// the error and stack trace in Error level.
// This function completely overrides the default behavior of the middleware.
//
// Ensure that this function does not panic, as it will be called in a deferred function!
func WithLogCallback(fn func(err error, stack []byte, brokenPipe bool)) Option {
return func(o *options) {
o.logCallback = fn
o.logCallbackOverride = true
}
}
// WithLogger is a method that sets the logger for the logging middleware.
// If a logger is set, the logging middleware will use this logger to log messages.
// The default logger is the structured slog logger, see WithSlog.
func WithLogger(logger Logger) Option {
return func(o *options) {
o.logger = logger
}
}
// WithSlog is a method that sets whether the recovery middleware should use the structured slog logger.
// If set to true, the middleware will use the structured slog logger. If set to false, the middleware
// will not use any logger unless one is explicitly set with the WithLogger option.
// The default value is true.
func WithSlog(useSlog bool) Option {
return func(o *options) {
o.useSlog = useSlog
}
}
// WithDefaultLogPrefix is a method that sets the default log prefix for the recovery middleware.
// If a default log prefix is set and the default log callback is used, the prefix will be prepended
// to each log message. A space will be added between the prefix and the log message.
// The default value is an empty string.
func WithDefaultLogPrefix(defaultLogPrefix string) Option {
return func(o *options) {
o.defaultLogPrefix = defaultLogPrefix
}
}
// WithExposeStackTrace is a method that sets whether the stack trace should be exposed in the response.
// If set to true, the stack trace will be included in the response body. If set to false, the stack trace
// will not be included in the response body. This only applies to the default error callback.
// The default value is false.
func WithExposeStackTrace(exposeStackTrace bool) Option {
return func(o *options) {
o.exposeStackTrace = exposeStackTrace
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
logger: nil,
useSlog: true,
errCallback: nil,
brokenPipeCallback: nil, // by default, ignore broken pipe errors
exposeStackTrace: false,
defaultLogPrefix: "",
logCallback: nil,
}
for _, opt := range opts {
opt(&o)
}
if o.errCallback == nil && !o.errCallbackOverride {
o.errCallback = getDefaultErrCallback(o)
}
if o.logCallback == nil && !o.logCallbackOverride {
o.logCallback = getDefaultLogCallback(o)
}
return o
}

View File

@ -0,0 +1,100 @@
package recovery
import (
"net/http"
"testing"
)
func TestWithErrCallback(t *testing.T) {
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
opt := WithErrCallback(callback)
o := newOptions(opt)
if o.errCallback == nil {
t.Errorf("expected errCallback to be set, got nil")
}
}
func TestWithBrokenPipeCallback(t *testing.T) {
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
opt := WithBrokenPipeCallback(callback)
o := newOptions(opt)
if o.brokenPipeCallback == nil {
t.Errorf("expected brokenPipeCallback to be set, got nil")
}
}
func TestWithLogCallback(t *testing.T) {
callback := func(err error, stack []byte, brokenPipe bool) {}
opt := WithLogCallback(callback)
o := newOptions(opt)
if o.logCallback == nil {
t.Errorf("expected logCallback to be set, got nil")
}
}
func TestWithLogger(t *testing.T) {
logger := &mockLogger{}
opt := WithLogger(logger)
o := newOptions(opt)
if o.logger != logger {
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
}
}
func TestWithSlog(t *testing.T) {
opt := WithSlog(false)
o := newOptions(opt)
if o.useSlog != false {
t.Errorf("expected useSlog to be false, got %v", o.useSlog)
}
}
func TestWithDefaultLogPrefix(t *testing.T) {
prefix := "PREFIX"
opt := WithDefaultLogPrefix(prefix)
o := newOptions(opt)
if o.defaultLogPrefix != prefix {
t.Errorf("expected defaultLogPrefix to be %v, got %v", prefix, o.defaultLogPrefix)
}
}
func TestWithExposeStackTrace(t *testing.T) {
opt := WithExposeStackTrace(true)
o := newOptions(opt)
if o.exposeStackTrace != true {
t.Errorf("expected exposeStackTrace to be true, got %v", o.exposeStackTrace)
}
}
func TestNewOptionsDefaults(t *testing.T) {
o := newOptions()
if o.logger != nil {
t.Errorf("expected logger to be nil, got %v", o.logger)
}
if o.useSlog != true {
t.Errorf("expected useSlog to be true, got %v", o.useSlog)
}
if o.errCallback == nil {
t.Errorf("expected errCallback to be set, got nil")
}
if o.brokenPipeCallback != nil {
t.Errorf("expected brokenPipeCallback to be nil, got %T", o.brokenPipeCallback)
}
if o.exposeStackTrace != false {
t.Errorf("expected exposeStackTrace to be false, got %T", o.exposeStackTrace)
}
if o.defaultLogPrefix != "" {
t.Errorf("expected defaultLogPrefix to be empty, got %T", o.defaultLogPrefix)
}
if o.logCallback == nil {
t.Errorf("expected logCallback to be set, got nil")
}
}

View File

@ -0,0 +1,69 @@
package tracing
import (
"context"
"math/rand"
"net/http"
)
// Middleware is a type that creates a new tracing middleware. The tracing middleware
// can be used to trace requests based on a request ID header or parameter.
type Middleware struct {
o options
seededRand *rand.Rand
}
// New returns a new CORS middleware with the provided options.
func New(opts ...Option) *Middleware {
o := newOptions(opts...)
m := &Middleware{
o: o,
seededRand: rand.New(rand.NewSource(o.generateSeed)),
}
return m
}
// Handler returns the tracing middleware handler.
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqId string
// read upstream header und re-use it
if m.o.upstreamReqIdHeader != "" {
reqId = r.Header.Get(m.o.upstreamReqIdHeader)
}
// generate new id
if reqId == "" && m.o.generateLength > 0 {
reqId = m.generateRandomId()
}
// set response header
if m.o.headerIdentifier != "" {
w.Header().Set(m.o.headerIdentifier, reqId)
}
// set context value
if m.o.contextIdentifier != "" {
ctx := context.WithValue(r.Context(), m.o.contextIdentifier, reqId)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r) // execute the next handler
})
}
// region internal-helpers
func (m *Middleware) generateRandomId() string {
b := make([]byte, m.o.generateLength)
for i := range b {
b[i] = m.o.generateCharset[m.seededRand.Intn(len(m.o.generateCharset))]
}
return string(b)
}
// endregion internal-helpers

View File

@ -0,0 +1,118 @@
package tracing
import (
"net/http"
"net/http/httptest"
"testing"
)
const defaultLength = 8
const upstreamHeaderValue = "upstream-id"
func TestMiddleware_Handler_WithUpstreamHeader(t *testing.T) {
m := New(WithUpstreamHeader("X-Upstream-Id"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Header.Get("X-Upstream-Id")
if reqId != upstreamHeaderValue {
t.Errorf("expected upstream request id to be 'upstream-id', got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Upstream-Id", upstreamHeaderValue)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-Id") != upstreamHeaderValue {
t.Errorf("expected X-Request-Id header to be set in the response")
}
}
func TestMiddleware_Handler_GenerateNewId(t *testing.T) {
idLen := 18
m := New(WithIdLength(idLen))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if len(reqId) != 18 {
t.Errorf("expected generated request id length to be %d, got %d", idLen, len(reqId))
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-Id") == "" || len(rr.Header().Get("X-Request-Id")) != idLen {
t.Errorf("expected X-Request-Id header to be set in the response")
}
}
func TestMiddleware_Handler_SetContextValue(t *testing.T) {
m := New()
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("RequestId").(string)
if reqId == "" || len(reqId) != defaultLength {
t.Errorf("expected context request id to be set, got empty string")
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_SetCustomContextValue(t *testing.T) {
m := New(WithContextIdentifier("Custom-Id"))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("Custom-Id").(string)
if reqId == "" || len(reqId) != defaultLength {
t.Errorf("expected context request id to be set, got empty string")
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdGenerated(t *testing.T) {
m := New(WithIdLength(0))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if reqId != "" {
t.Errorf("expected no request id to be generated, got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdHeaderSet(t *testing.T) {
m := New(WithHeaderIdentifier(""))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := w.Header().Get("X-Request-Id")
if reqId != "" {
t.Errorf("expected no request id to be generated, got %s", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
func TestMiddleware_Handler_NoIdContextSet(t *testing.T) {
m := New(WithHeaderIdentifier(""))
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqId := r.Context().Value("Request-Id")
if reqId != nil {
t.Errorf("expected no context request id to be set, got %v", reqId)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}

View File

@ -0,0 +1,85 @@
package tracing
import "time"
// options is a struct that contains options for the tracing middleware.
// It uses the functional options pattern for flexible configuration.
type options struct {
upstreamReqIdHeader string
headerIdentifier string
contextIdentifier string
generateLength int
generateCharset string
generateSeed int64
}
// Option is a type that is used to set options for the tracing middleware.
// It implements the functional options pattern.
type Option func(*options)
// WithIdSeed sets the seed for the random request id.
// If no seed is provided, the current timestamp is used.
func WithIdSeed(seed int64) Option {
return func(o *options) {
o.generateSeed = seed
}
}
// WithIdCharset sets the charset that is used to generate a random request id.
// By default, upper-case letters and numbers are used.
func WithIdCharset(charset string) Option {
return func(o *options) {
o.generateCharset = charset
}
}
// WithIdLength specifies the length of generated random ids.
// By default, a length of 8 is used. If the length is 0, no request id will be generated.
func WithIdLength(len int) Option {
return func(o *options) {
o.generateLength = len
}
}
// WithHeaderIdentifier specifies the header name for the request id that is added to the response headers.
// If the identifier is empty, the request id will not be added to the response headers.
func WithHeaderIdentifier(identifier string) Option {
return func(o *options) {
o.headerIdentifier = identifier
}
}
// WithUpstreamHeader sets the upstream header name, that should be used to fetch the request id.
// If no upstream header is found, a random id will be generated if the id-length parameter is set to a value > 0.
func WithUpstreamHeader(header string) Option {
return func(o *options) {
o.upstreamReqIdHeader = header
}
}
// WithContextIdentifier specifies the value-key for the request id that is added to the request context.
// If the identifier is empty, the request id will not be added to the context.
// If the request id is added to the context, it can be retrieved with:
// `id := r.Context().Value(THE-IDENTIFIER).(string)`
func WithContextIdentifier(identifier string) Option {
return func(o *options) {
o.contextIdentifier = identifier
}
}
// newOptions is a function that returns a new options struct with sane default values.
func newOptions(opts ...Option) options {
o := options{
headerIdentifier: "X-Request-Id",
contextIdentifier: "RequestId",
generateSeed: time.Now().UnixNano(),
generateCharset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
generateLength: 8,
}
for _, opt := range opts {
opt(&o)
}
return o
}

View File

@ -0,0 +1,75 @@
package tracing
import (
"testing"
)
func TestWithIdSeed(t *testing.T) {
o := newOptions(WithIdSeed(12345))
if o.generateSeed != 12345 {
t.Errorf("expected generateSeed to be 12345, got %d", o.generateSeed)
}
}
func TestWithIdCharset(t *testing.T) {
o := newOptions(WithIdCharset("abc123"))
if o.generateCharset != "abc123" {
t.Errorf("expected generateCharset to be 'abc123', got %s", o.generateCharset)
}
}
func TestWithIdLength(t *testing.T) {
o := newOptions(WithIdLength(16))
if o.generateLength != 16 {
t.Errorf("expected generateLength to be 16, got %d", o.generateLength)
}
}
func TestWithHeaderIdentifier(t *testing.T) {
o := newOptions(WithHeaderIdentifier("X-Custom-Id"))
if o.headerIdentifier != "X-Custom-Id" {
t.Errorf("expected headerIdentifier to be 'X-Custom-Id', got %s", o.headerIdentifier)
}
}
func TestWithUpstreamHeader(t *testing.T) {
o := newOptions(WithUpstreamHeader("X-Upstream-Id"))
if o.upstreamReqIdHeader != "X-Upstream-Id" {
t.Errorf("expected upstreamReqIdHeader to be 'X-Upstream-Id', got %s", o.upstreamReqIdHeader)
}
}
func TestWithContextIdentifier(t *testing.T) {
o := newOptions(WithContextIdentifier("Request-Id"))
if o.contextIdentifier != "Request-Id" {
t.Errorf("expected contextIdentifier to be 'Request-Id', got %s", o.contextIdentifier)
}
}
func TestDefaults(t *testing.T) {
o := newOptions()
if o.generateLength != 8 {
t.Errorf("expected generateLength to be 8, got %d", o.generateLength)
}
if o.generateCharset != "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" {
t.Errorf("expected generateCharset to be 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', got %s", o.generateCharset)
}
if o.generateSeed == 0 {
t.Errorf("expected generateSeed to be non-zero")
}
if o.headerIdentifier != "X-Request-Id" {
t.Errorf("expected headerIdentifier to be 'X-Request-Id', got %s", o.headerIdentifier)
}
if o.upstreamReqIdHeader != "" {
t.Errorf("expected upstreamReqIdHeader to be empty, got %s", o.upstreamReqIdHeader)
}
if o.contextIdentifier != "RequestId" {
t.Errorf("expected contextIdentifier to be 'RequestId', got %s", o.contextIdentifier)
}
}

View File

@ -0,0 +1,259 @@
// Package request provides functions to extract parameters from the request.
package request
import (
"encoding/json"
"io"
"net"
"net/http"
"net/textproto"
"slices"
"strings"
)
const CheckPrivateProxy = "PRIVATE"
// PathRaw returns the value of the named path parameter.
func PathRaw(r *http.Request, name string) string {
return r.PathValue(name)
}
// Path returns the value of the named path parameter.
// The return value is trimmed of leading and trailing whitespace.
func Path(r *http.Request, name string) string {
return strings.TrimSpace(PathRaw(r, name))
}
// PathDefault returns the value of the named path parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func PathDefault(r *http.Request, name string, defaultValue string) string {
value := r.PathValue(name)
if value == "" {
return defaultValue
}
return Path(r, name)
}
// QueryRaw returns the value of the named query parameter.
func QueryRaw(r *http.Request, name string) string {
return r.URL.Query().Get(name)
}
// Query returns the value of the named query parameter.
// The return value is trimmed of leading and trailing whitespace.
func Query(r *http.Request, name string) string {
return strings.TrimSpace(QueryRaw(r, name))
}
// QueryDefault returns the value of the named query parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func QueryDefault(r *http.Request, name string, defaultValue string) string {
if !r.URL.Query().Has(name) {
return defaultValue
}
return Query(r, name)
}
// QuerySlice returns the value of the named query parameter.
// All slice values are trimmed of leading and trailing whitespace.
func QuerySlice(r *http.Request, name string) []string {
values, ok := r.URL.Query()[name]
if !ok {
return nil
}
result := make([]string, len(values))
for i, value := range values {
result[i] = strings.TrimSpace(value)
}
return result
}
// QuerySliceDefault returns the value of the named query parameter.
// If the parameter is empty, it returns the default value.
// All slice values are trimmed of leading and trailing whitespace.
func QuerySliceDefault(r *http.Request, name string, defaultValue []string) []string {
if !r.URL.Query().Has(name) {
return defaultValue
}
return QuerySlice(r, name)
}
// FragmentRaw returns the value of the named fragment parameter.
func FragmentRaw(r *http.Request) string {
return r.URL.Fragment
}
// Fragment returns the value of the named fragment parameter.
// The return value is trimmed of leading and trailing whitespace.
func Fragment(r *http.Request) string {
return strings.TrimSpace(FragmentRaw(r))
}
// FragmentDefault returns the value of the named fragment parameter.
// If the parameter is empty, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func FragmentDefault(r *http.Request, defaultValue string) string {
if r.URL.Fragment == "" {
return defaultValue
}
return Fragment(r)
}
// FormRaw returns the value of the named form parameter.
func FormRaw(r *http.Request, name string) string {
return r.FormValue(name)
}
// Form returns the value of the named form parameter.
// The return value is trimmed of leading and trailing whitespace.
func Form(r *http.Request, name string) string {
return strings.TrimSpace(FormRaw(r, name))
}
// DefaultForm returns the value of the named form parameter.
// If the parameter is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func DefaultForm(r *http.Request, name, defaultValue string) string {
err := r.ParseForm()
if err != nil {
return defaultValue
}
if !r.Form.Has(name) {
return defaultValue
}
return Form(r, name)
}
// HeaderRaw returns the value of the named header.
func HeaderRaw(r *http.Request, name string) string {
return r.Header.Get(name)
}
// Header returns the value of the named header.
// The return value is trimmed of leading and trailing whitespace.
func Header(r *http.Request, name string) string {
return strings.TrimSpace(HeaderRaw(r, name))
}
// HeaderDefault returns the value of the named header.
// If the header is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func HeaderDefault(r *http.Request, name, defaultValue string) string {
if _, ok := textproto.MIMEHeader(r.Header)[name]; !ok {
return defaultValue
}
return Header(r, name)
}
// Cookie returns the value of the named cookie.
// The return value is trimmed of leading and trailing whitespace.
func Cookie(r *http.Request, name string) string {
cookie, err := r.Cookie(name)
if err != nil {
return ""
}
return strings.TrimSpace(cookie.Value)
}
// CookieDefault returns the value of the named cookie.
// If the cookie is not set, it returns the default value.
// The return value is trimmed of leading and trailing whitespace.
func CookieDefault(r *http.Request, name, defaultValue string) string {
cookie, err := r.Cookie(name)
if err != nil {
return defaultValue
}
return strings.TrimSpace(cookie.Value)
}
// ClientIp returns the client IP address.
//
// As the request may come from a proxy, the function checks the
// X-Real-Ip and X-Forwarded-For headers to get the real client IP
// if the request IP matches one of the allowed proxy IPs.
// If the special proxy value CheckPrivateProxy ("PRIVATE") is passed, the function will
// also check the header if the request IP is a private IP address.
func ClientIp(r *http.Request, allowedProxyIp ...string) string {
ipStr, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
switch {
case err != nil && strings.Contains(err.Error(), "missing port in address"):
ipStr = strings.TrimSpace(r.RemoteAddr)
case err != nil:
ipStr = ""
}
IP := net.ParseIP(ipStr)
if IP == nil {
return ""
}
isProxiedRequest := false
if len(allowedProxyIp) > 0 {
if slices.Contains(allowedProxyIp, IP.String()) {
isProxiedRequest = true
}
if IP.IsPrivate() && slices.Contains(allowedProxyIp, CheckPrivateProxy) {
isProxiedRequest = true
}
}
if isProxiedRequest {
realClientIP := r.Header.Get("X-Real-Ip")
if realClientIP == "" {
realClientIP = r.Header.Get("X-Forwarded-For")
}
if realClientIP != "" {
realIpStr, _, err := net.SplitHostPort(strings.TrimSpace(realClientIP))
switch {
case err != nil && strings.Contains(err.Error(), "missing port in address"):
realIpStr = realClientIP
case err != nil:
realIpStr = ipStr
}
realIP := net.ParseIP(realIpStr)
if realIP == nil {
return IP.String()
}
return realIP.String()
}
}
return IP.String()
}
// BodyJson decodes the JSON value from the request body into the target.
// The target must be a pointer to a struct or slice.
// The function returns an error if the JSON value could not be decoded.
// The body reader is closed after reading.
func BodyJson(r *http.Request, target any) error {
defer func() {
_ = r.Body.Close()
}()
return json.NewDecoder(r.Body).Decode(target)
}
// BodyString returns the request body as a string.
// The content is read and returned as is, without any processing.
// The body is assumed to be UTF-8 encoded.
func BodyString(r *http.Request) (string, error) {
defer func() {
_ = r.Body.Close()
}()
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
return string(bodyBytes), nil
}

View File

@ -0,0 +1,221 @@
package request
import (
"io"
"net/http"
"net/url"
"slices"
"strings"
"testing"
)
func TestPath(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/test/sample"}}
r.SetPathValue("first", "test")
if got := Path(r, "first"); got != "test" {
t.Errorf("Path() = %v, want %v", got, "test")
}
}
func TestDefaultPath(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/"}}
if got := PathDefault(r, "test", "default"); got != "default" {
t.Errorf("PathDefault() = %v, want %v", got, "default")
}
}
func TestDefaultPath_noDefault(t *testing.T) {
r := &http.Request{URL: &url.URL{Path: "/"}}
r.SetPathValue("first", "test")
if got := PathDefault(r, "first", "test"); got != "test" {
t.Errorf("PathDefault() = %v, want %v", got, "test")
}
}
func TestQuery(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value"}}
if got := Query(r, "name"); got != "value" {
t.Errorf("Query() = %v, want %v", got, "value")
}
}
func TestDefaultQuery(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: ""}}
if got := QueryDefault(r, "name", "default"); got != "default" {
t.Errorf("QueryDefault() = %v, want %v", got, "default")
}
}
func TestQuerySlice(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value1 &name=value2"}}
expected := []string{"value1", "value2"}
if got := QuerySlice(r, "name"); !slices.Equal(got, expected) {
t.Errorf("QuerySlice() = %v, want %v", got, expected)
}
}
func TestQuerySlice_empty(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: "name=value1&name=value2"}}
if got := QuerySlice(r, "nix"); !slices.Equal(got, nil) {
t.Errorf("QuerySlice() = %v, want %v", got, nil)
}
}
func TestDefaultQuerySlice(t *testing.T) {
r := &http.Request{URL: &url.URL{RawQuery: ""}}
defaultValue := []string{"default1", "default2"}
if got := QuerySliceDefault(r, "name", defaultValue); !slices.Equal(got, defaultValue) {
t.Errorf("QuerySliceDefault() = %v, want %v", got, defaultValue)
}
}
func TestFragment(t *testing.T) {
r := &http.Request{URL: &url.URL{Fragment: "section"}}
if got := Fragment(r); got != "section" {
t.Errorf("Fragment() = %v, want %v", got, "section")
}
}
func TestDefaultFragment(t *testing.T) {
r := &http.Request{URL: &url.URL{Fragment: ""}}
if got := FragmentDefault(r, "default"); got != "default" {
t.Errorf("FragmentDefault() = %v, want %v", got, "default")
}
}
func TestForm(t *testing.T) {
r := &http.Request{Form: url.Values{"name": {"value"}}}
if got := Form(r, "name"); got != "value" {
t.Errorf("Form() = %v, want %v", got, "value")
}
}
func TestDefaultForm(t *testing.T) {
r := &http.Request{Form: url.Values{}}
if got := DefaultForm(r, "name", "default"); got != "default" {
t.Errorf("DefaultForm() = %v, want %v", got, "default")
}
}
func TestHeader(t *testing.T) {
r := &http.Request{Header: http.Header{"X-Test-Header": {"value"}}}
if got := Header(r, "X-Test-Header"); got != "value" {
t.Errorf("Header() = %v, want %v", got, "value")
}
}
func TestDefaultHeader(t *testing.T) {
r := &http.Request{Header: http.Header{}}
if got := HeaderDefault(r, "X-Test-Header", "default"); got != "default" {
t.Errorf("HeaderDefault() = %v, want %v", got, "default")
}
}
func TestCookie(t *testing.T) {
r := &http.Request{Header: http.Header{"Cookie": {"name=value"}}}
if got := Cookie(r, "name"); got != "value" {
t.Errorf("Cookie() = %v, want %v", got, "value")
}
}
func TestDefaultCookie(t *testing.T) {
r := &http.Request{Header: http.Header{}}
if got := CookieDefault(r, "name", "default"); got != "default" {
t.Errorf("CookieDefault() = %v, want %v", got, "default")
}
}
func TestClientIp(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345"}
if got := ClientIp(r); got != "192.168.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
}
}
func TestClientIp_invalid(t *testing.T) {
r := &http.Request{RemoteAddr: "was_isn_des"}
if got := ClientIp(r); got != "" {
t.Errorf("ClientIp() = %v, want %v", got, "")
}
}
func TestClientIp_ignore_header(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
if got := ClientIp(r); got != "192.168.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
}
}
func TestClientIp_header1(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header2(t *testing.T) {
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header3(t *testing.T) {
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, "1.1.1.1"); got != "123.45.67.1" {
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
}
}
func TestClientIp_header4(t *testing.T) {
r := &http.Request{RemoteAddr: "8.8.8.8:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
if got := ClientIp(r, "1.1.1.1"); got != "8.8.8.8" {
t.Errorf("ClientIp() = %v, want %v", got, "8.8.8.8")
}
}
func TestClientIp_header_invalid(t *testing.T) {
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"so-sicher-nit"}}}
if got := ClientIp(r, "1.1.1.1"); got != "1.1.1.1" {
t.Errorf("ClientIp() = %v, want %v", got, "1.1.1.1")
}
}
func TestBodyJson(t *testing.T) {
type TestStruct struct {
Name string `json:"name"`
Value int `json:"value"`
}
jsonStr := `{"name": "test", "value": 123}`
r := &http.Request{
Body: io.NopCloser(strings.NewReader(jsonStr)),
}
var result TestStruct
err := BodyJson(r, &result)
if err != nil {
t.Fatalf("BodyJson() error = %v", err)
}
expected := TestStruct{Name: "test", Value: 123}
if result != expected {
t.Errorf("BodyJson() = %v, want %v", result, expected)
}
}
func TestBodyString(t *testing.T) {
bodyStr := "test body content"
r := &http.Request{
Body: io.NopCloser(strings.NewReader(bodyStr)),
}
result, err := BodyString(r)
if err != nil {
t.Fatalf("BodyString() error = %v", err)
}
if result != bodyStr {
t.Errorf("BodyString() = %v, want %v", result, bodyStr)
}
}

View File

@ -0,0 +1,100 @@
// Package respond provides a set of utility functions to help with the HTTP response handling.
package respond
import (
"encoding/json"
"io"
"net/http"
"strconv"
)
// Status writes a response with the given status code.
// The response will not contain any data.
func Status(w http.ResponseWriter, code int) {
w.WriteHeader(code)
}
// String writes a plain text response with the given status code and data.
// The Content-Type header is set to text/plain with a charset of utf-8.
func String(w http.ResponseWriter, code int, data string) {
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
w.WriteHeader(code)
_, _ = w.Write([]byte(data))
}
// JSON writes a JSON response with the given status code and data.
// If data is nil, the response will null. The status code is set to the given code.
// The Content-Type header is set to application/json.
// If the given data is not JSON serializable, the response will not contain any data.
// All encoding errors are silently ignored.
func JSON(w http.ResponseWriter, code int, data any) {
w.Header().Set("Content-Type", "application/json")
// if no data was given, simply return null
if data == nil {
w.WriteHeader(code)
_, _ = w.Write([]byte("null"))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(data)
}
// Data writes a response with the given status code, content type, and data.
// If no content type is provided, it is detected from the data.
func Data(w http.ResponseWriter, code int, contentType string, data []byte) {
if contentType == "" {
contentType = http.DetectContentType(data) // ensure content type is set
}
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(code)
_, _ = w.Write(data)
}
// Reader writes a response with the given status code, content type, and data.
// The content length is optional, it is only set if the given length is greater than 0.
func Reader(w http.ResponseWriter, code int, contentType string, contentLength int, data io.Reader) {
w.Header().Set("Content-Type", contentType)
if contentLength > 0 {
w.Header().Set("Content-Length", strconv.Itoa(contentLength))
}
w.WriteHeader(code)
_, _ = io.Copy(w, data)
}
// Attachment writes a response with the given status code, content type, filename, and data.
// If no content type is provided, it is detected from the data.
func Attachment(w http.ResponseWriter, code int, filename, contentType string, data []byte) {
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
Data(w, code, contentType, data)
}
// AttachmentReader writes a response with the given status code, content type, filename, content length, and data.
// The content length is optional, it is only set if the given length is greater than 0.
func AttachmentReader(
w http.ResponseWriter,
code int,
filename, contentType string,
contentLength int,
data io.Reader,
) {
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
Reader(w, code, contentType, contentLength, data)
}
// Redirect writes a response with the given status code and redirects to the given URL.
// The redirect url will always be an absolute URL. If the given URL is relative,
// the original request URL is used as the base.
func Redirect(w http.ResponseWriter, r *http.Request, code int, url string) {
http.Redirect(w, r, url, code)
}

View File

@ -0,0 +1,273 @@
package respond
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
)
func TestStatus(t *testing.T) {
rec := httptest.NewRecorder()
Status(rec, http.StatusNoContent)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
t.Errorf("expected status %d, got %d", http.StatusNoContent, res.StatusCode)
}
body, _ := io.ReadAll(res.Body)
if len(body) != 0 {
t.Errorf("expected no body, got %s", body)
}
}
func TestString(t *testing.T) {
rec := httptest.NewRecorder()
String(rec, http.StatusOK, "Hello, World!")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain;charset=utf-8" {
t.Errorf("expected content type %s, got %s", "text/plain;charset=utf-8", contentType)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestJSON(t *testing.T) {
rec := httptest.NewRecorder()
data := map[string]string{"hello": "world"}
JSON(rec, http.StatusOK, data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
t.Errorf("expected content type %s, got %s", "application/json", contentType)
}
var body map[string]string
_ = json.NewDecoder(res.Body).Decode(&body)
if body["hello"] != "world" {
t.Errorf("expected body %v, got %v", data, body)
}
}
func TestJSON_empty(t *testing.T) {
rec := httptest.NewRecorder()
JSON(rec, http.StatusOK, nil)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
t.Errorf("expected content type %s, got %s", "application/json", contentType)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "null" {
t.Errorf("expected body %s, got %s", "null", body)
}
}
func TestData(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
Data(rec, http.StatusOK, "text/plain", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestData_noContentType(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte{0x1, 0x2, 0x3, 0x4, 0x5}
Data(rec, http.StatusOK, "", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "application/octet-stream" {
t.Errorf("expected content type %s, got %s", "application/octet-stream", contentType)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestReader(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
reader := bytes.NewBufferString(string(data))
Reader(rec, http.StatusOK, "text/plain", len(data), reader)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
if contentLength := res.Header.Get("Content-Length"); contentLength != strconv.Itoa(len(data)) {
t.Errorf("expected content length %d, got %s", len(data), contentLength)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestReader_unknownLength(t *testing.T) {
rec := httptest.NewRecorder()
data := bytes.NewBufferString("Hello, World!")
Reader(rec, http.StatusOK, "text/plain", 0, data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
if contentLength := res.Header.Get("Content-Length"); contentLength != "" {
t.Errorf("expected no content length, got %s", contentLength)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestAttachment(t *testing.T) {
rec := httptest.NewRecorder()
data := []byte("Hello, World!")
Attachment(rec, http.StatusOK, "example.txt", "text/plain", data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
}
body, _ := io.ReadAll(res.Body)
if !bytes.Equal(body, data) {
t.Errorf("expected body %s, got %s", data, body)
}
}
func TestAttachmentReader(t *testing.T) {
rec := httptest.NewRecorder()
data := bytes.NewBufferString("Hello, World!")
AttachmentReader(rec, http.StatusOK, "example.txt", "text/plain", data.Len(), data)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
}
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello, World!" {
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
}
}
func TestRedirect(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/old", nil)
url := "http://example.com/new"
Redirect(rec, req, http.StatusMovedPermanently, url)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusMovedPermanently {
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
}
if location := res.Header.Get("Location"); location != url {
t.Errorf("expected location %s, got %s", url, location)
}
}
func TestRedirect_relative(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/old/dir", nil)
url := "newlocation/sub"
want := "/old/newlocation/sub"
Redirect(rec, req, http.StatusMovedPermanently, url)
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusMovedPermanently {
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
}
if location := res.Header.Get("Location"); location != want {
t.Errorf("expected location %s, got %s", want, location)
}
}

View File

@ -0,0 +1,46 @@
package respond
import (
"fmt"
"io"
"net/http"
)
// TplData is a map of template data. This is a convenience type for passing data to templates.
type TplData map[string]any
// TemplateInstance is an interface that wraps the ExecuteTemplate method.
// It is implemented by the html/template and text/template packages.
type TemplateInstance interface {
// ExecuteTemplate executes a template with the given name and data.
ExecuteTemplate(wr io.Writer, name string, data any) error
}
// TemplateRenderer is a renderer that uses a template instance to render HTML or Text templates.
type TemplateRenderer struct {
t TemplateInstance
}
// NewTemplateRenderer creates a new HTML or Text template renderer with the given template instance.
func NewTemplateRenderer(t TemplateInstance) *TemplateRenderer {
return &TemplateRenderer{t: t}
}
// Render renders a template with the given name and data.
// If rendering fails, it will panic with an error.
func (r *TemplateRenderer) Render(w http.ResponseWriter, code int, name, contentType string, data any) {
w.Header().Set("Content-Type", contentType)
w.WriteHeader(code)
err := r.t.ExecuteTemplate(w, name, data)
if err != nil {
panic(fmt.Errorf("error rendering template %s: %v", name, err))
}
}
// HTML renders a template with the given name and data. It is a convenience method for Render.
// The content type is set to "text/html" and the encoding to "utf-8".
// If rendering fails, it will panic with an error.
func (r *TemplateRenderer) HTML(w http.ResponseWriter, code int, name string, data any) {
r.Render(w, code, name, "text/html;charset=utf-8", data)
}

View File

@ -0,0 +1,67 @@
package respond
import (
"html/template"
"io"
"net/http"
"net/http/httptest"
"testing"
)
type mockTemplate struct {
tmpl *template.Template
}
func (m *mockTemplate) ExecuteTemplate(wr io.Writer, name string, data any) error {
return m.tmpl.ExecuteTemplate(wr, name, data)
}
func TestTemplateRenderer_Render(t *testing.T) {
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}Hello, {{.}}!{{end}}`))
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
rec := httptest.NewRecorder()
renderer.Render(rec, http.StatusOK, "test", "text/plain", "World")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
}
body, _ := io.ReadAll(res.Body)
expectedBody := "Hello, World!"
if string(body) != expectedBody {
t.Errorf("expected body %s, got %s", expectedBody, string(body))
}
}
func TestTemplateRenderer_HTML(t *testing.T) {
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}<p>Hello, {{.}}!</p>{{end}}`))
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
rec := httptest.NewRecorder()
renderer.HTML(rec, http.StatusOK, "test", "World")
res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
}
if contentType := res.Header.Get("Content-Type"); contentType != "text/html;charset=utf-8" {
t.Errorf("expected content type %s, got %s", "text/html;charset=utf-8", contentType)
}
body, _ := io.ReadAll(res.Body)
expectedBody := "<p>Hello, World!</p>"
if string(body) != expectedBody {
t.Errorf("expected body %s, got %s", expectedBody, string(body))
}
}

View File

@ -2,27 +2,25 @@ package core
import (
"context"
"encoding/base64"
"fmt"
"html/template"
"io"
"io/fs"
"log/slog"
"math/rand"
"net/http"
"os"
"time"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/logging"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/recovery"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/tracing"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/config"
)
var (
random = rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
)
const (
RequestIDKey = "X-Request-ID"
)
@ -30,19 +28,21 @@ const (
type ApiVersion string
type HandlerName string
type GroupSetupFn func(group *gin.RouterGroup)
type GroupSetupFn func(group *routegroup.Bundle)
type ApiEndpointSetupFunc func() (ApiVersion, GroupSetupFn)
type Server struct {
cfg *config.Config
server *gin.Engine
versions map[ApiVersion]*gin.RouterGroup
server *routegroup.Bundle
tpl *respond.TemplateRenderer
versions map[ApiVersion]*routegroup.Bundle
}
func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, error) {
s := &Server{
cfg: cfg,
cfg: cfg,
server: routegroup.New(http.NewServeMux()),
}
hostname, err := os.Hostname()
@ -51,69 +51,39 @@ func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server,
}
hostname += ", version " + internal.Version
// Setup http server
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
s.server = gin.New()
s.server.Use(recovery.New().Handler)
if cfg.Web.RequestLogging {
if cfg.Advanced.LogLevel == "trace" {
gin.SetMode(gin.DebugMode)
}
s.server.Use(func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
s.server.Use(logging.New(logging.WithLevel(logging.LogLevelDebug)).Handler)
c.Next()
if raw != "" {
path = path + "?" + raw
}
latency := time.Since(start)
status := c.Writer.Status()
clientIP := c.ClientIP()
method := c.Request.Method
errorMsg := c.Errors.ByType(gin.ErrorTypePrivate).String()
slog.Debug("HTTP Request",
"status", status,
"latency", latency,
"client", clientIP,
"method", method,
"path", path,
"error", errorMsg,
)
}
s.server.Use(cors.New().Handler)
s.server.Use(tracing.New(
tracing.WithContextIdentifier(RequestIDKey),
tracing.WithHeaderIdentifier(RequestIDKey),
).Handler)
if cfg.Web.ExposeHostInfo {
s.server.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Served-By", hostname)
handler.ServeHTTP(w, r)
})
})
}
s.server.Use(gin.Recovery()).Use(func(c *gin.Context) {
c.Writer.Header().Set("X-Served-By", hostname)
c.Next()
}).Use(func(c *gin.Context) {
xRequestID := uuid(16)
c.Request.Header.Set(RequestIDKey, xRequestID)
c.Set(RequestIDKey, xRequestID)
c.Next()
})
// Setup templates
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(apiTemplates, "assets/tpl/*.gohtml"))
s.server.SetHTMLTemplate(templates)
s.tpl = respond.NewTemplateRenderer(
template.Must(template.New("").ParseFS(apiTemplates, "assets/tpl/*.gohtml")),
)
// Serve static files
imgFs := http.FS(fsMust(fs.Sub(apiStatics, "assets/img")))
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
s.server.StaticFS("/img", imgFs)
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
s.server.StaticFS("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
s.server.HandleFiles("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
s.server.HandleFiles("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
s.server.HandleFiles("/img", imgFs)
s.server.HandleFiles("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
s.server.HandleFiles("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
// Setup routes
s.server.UseRawPath = true
s.server.UnescapePathValues = true
s.setupRoutes(endpoints...)
s.setupFrontendRoutes()
@ -136,9 +106,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
err = srv.ListenAndServe()
}
if err != nil {
slog.Info("web service exited",
"address", listenAddress,
"error", err)
slog.Info("web service exited", "address", listenAddress, "error", err)
cancelFn()
}
}()
@ -157,18 +125,18 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
}
func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
s.server.GET("/api", s.landingPage)
s.versions = make(map[ApiVersion]*gin.RouterGroup)
s.server.HandleFunc("GET /api", s.landingPage)
s.versions = make(map[ApiVersion]*routegroup.Bundle)
for _, setupFunc := range endpoints {
version, groupSetupFn := setupFunc()
if _, ok := s.versions[version]; !ok {
s.versions[version] = s.server.Group(fmt.Sprintf("/api/%s", version))
s.versions[version] = s.server.Mount(fmt.Sprintf("/api/%s", version))
// OpenAPI documentation (via RapiDoc)
s.versions[version].GET("/swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
s.versions[version].GET("/doc.html", s.rapiDocHandler(version))
s.versions[version].HandleFunc("GET /swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
s.versions[version].HandleFunc("GET /doc.html", s.rapiDocHandler(version))
groupSetupFn(s.versions[version])
}
@ -177,25 +145,27 @@ func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
func (s *Server) setupFrontendRoutes() {
// Serve static files
s.server.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/app")
s.server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
respond.Redirect(w, r, http.StatusMovedPermanently, "/app")
})
s.server.GET("/favicon.ico", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/app/favicon.ico")
s.server.HandleFunc("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
respond.Redirect(w, r, http.StatusMovedPermanently, "/app/favicon.ico")
})
s.server.StaticFS("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
s.server.HandleFiles("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
}
func (s *Server) landingPage(c *gin.Context) {
c.HTML(http.StatusOK, "index.gohtml", gin.H{
func (s *Server) landingPage(w http.ResponseWriter, _ *http.Request) {
s.tpl.HTML(w, http.StatusOK, "index.gohtml", respond.TplData{
"Version": internal.Version,
"Year": time.Now().Year(),
})
}
func (s *Server) rapiDocHandler(version ApiVersion) gin.HandlerFunc {
return func(c *gin.Context) {
c.HTML(http.StatusOK, "rapidoc.gohtml", gin.H{
func (s *Server) rapiDocHandler(version ApiVersion) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.tpl.HTML(w, http.StatusOK, "rapidoc.gohtml", respond.TplData{
"RapiDocSource": "/js/rapidoc-min.js",
"ApiSpecUrl": fmt.Sprintf("/doc/%s_swagger.yaml", version),
"Version": internal.Version,
@ -210,9 +180,3 @@ func fsMust(f fs.FS, err error) fs.FS {
}
return f
}
func uuid(len int) string {
bytes := make([]byte, len)
random.Read(bytes)
return base64.StdEncoding.EncodeToString(bytes)[:len]
}

View File

@ -1,24 +1,46 @@
package handlers
import (
"context"
"net/http"
"strings"
"github.com/gin-contrib/cors"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin"
csrf "github.com/utrack/gin-csrf"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/csrf"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
)
type handler interface {
type SessionMiddleware interface {
// SetData sets the session data for the given context.
SetData(ctx context.Context, val SessionData)
// GetData returns the session data for the given context. If no data is found, the default session data is returned.
GetData(ctx context.Context) SessionData
// DestroyData destroys the session data for the given context.
DestroyData(ctx context.Context)
// GetString returns the string value for the given key. If no value is found, an empty string is returned.
GetString(ctx context.Context, key string) string
// Put sets the value for the given key.
Put(ctx context.Context, key string, value any)
// LoadAndSave is a middleware that loads the session data for the given request and saves it after the request is
// finished.
LoadAndSave(next http.Handler) http.Handler
}
type Handler interface {
// GetName returns the name of the handler.
GetName() string
RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler)
// RegisterRoutes registers the routes for the handler. The session manager is passed to the handler.
RegisterRoutes(g *routegroup.Bundle)
}
type Authenticator interface {
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
UserIdMatch(idParameter string) func(next http.Handler) http.Handler
}
// To compile the API documentation use the
@ -35,54 +57,33 @@ type handler interface {
// @BasePath /api/v0
// @query.collection.format multi
func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc {
authenticator := &authenticationHandler{
app: app,
Session: GinSessionStore{sessionIdentifier: cfg.Web.SessionIdentifier},
}
handlers := make([]handler, 0, 1)
handlers = append(handlers, testEndpoint{})
handlers = append(handlers, userEndpoint{app: app, authenticator: authenticator})
handlers = append(handlers, newConfigEndpoint(app, authenticator))
handlers = append(handlers, authEndpoint{app: app, authenticator: authenticator})
handlers = append(handlers, interfaceEndpoint{app: app, authenticator: authenticator})
handlers = append(handlers, peerEndpoint{app: app, authenticator: authenticator})
func NewRestApi(
session SessionMiddleware,
handlers ...Handler,
) core.ApiEndpointSetupFunc {
return func() (core.ApiVersion, core.GroupSetupFn) {
return "v0", func(group *gin.RouterGroup) {
cookieStore := memstore.NewStore([]byte(cfg.Web.SessionSecret))
cookieStore.Options(sessions.Options{
Path: "/",
MaxAge: 86400, // auth session is valid for 1 day
Secure: strings.HasPrefix(cfg.Web.ExternalUrl, "https"),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
return "v0", func(group *routegroup.Bundle) {
csrfMiddleware := csrf.New(func(r *http.Request) string {
return session.GetString(r.Context(), "csrf_token")
}, func(r *http.Request, token string) {
session.Put(r.Context(), "csrf_token", token)
})
group.Use(sessions.Sessions(cfg.Web.SessionIdentifier, cookieStore))
group.Use(cors.Default())
group.Use(csrf.Middleware(csrf.Options{
Secret: cfg.Web.CsrfSecret,
ErrorFunc: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, model.Error{
Code: http.StatusBadRequest,
Message: "CSRF token mismatch",
})
c.Abort()
},
}))
group.GET("/csrf", handleCsrfGet())
group.Use(session.LoadAndSave)
group.Use(csrfMiddleware.Handler)
group.Use(cors.New().Handler)
group.With(csrfMiddleware.RefreshToken).HandleFunc("GET /csrf", handleCsrfGet())
// Handler functions
for _, h := range handlers {
h.RegisterRoutes(group, authenticator)
h.RegisterRoutes(group)
}
}
}
}
// handleCsrfGet returns a gorm handler function.
// handleCsrfGet returns a gorm Handler function.
//
// @ID base_handleCsrfGet
// @Tags Security
@ -90,8 +91,12 @@ func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc {
// @Produce json
// @Success 200 {object} string
// @Router /csrf [get]
func handleCsrfGet() gin.HandlerFunc {
return func(c *gin.Context) {
c.JSON(http.StatusOK, csrf.GetToken(c))
func handleCsrfGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
respond.JSON(w, http.StatusOK, csrf.GetToken(r.Context()))
}
}
// region session wrapper
// endregion session wrapper

View File

@ -8,36 +8,62 @@ import (
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type authEndpoint struct {
app *app.App
authenticator *authenticationHandler
type Session interface {
// SetData sets the session data for the given context.
SetData(ctx context.Context, val SessionData)
// GetData returns the session data for the given context. If no data is found, the default session data is returned.
GetData(ctx context.Context) SessionData
// DestroyData destroys the session data for the given context.
DestroyData(ctx context.Context)
}
func (e authEndpoint) GetName() string {
type Validator interface {
Struct(s interface{}) error
}
type AuthEndpoint struct {
app *app.App
authenticator Authenticator
session Session
validate Validator
}
func NewAuthEndpoint(app *app.App, authenticator Authenticator, session Session, validator Validator) AuthEndpoint {
return AuthEndpoint{
app: app,
authenticator: authenticator,
session: session,
validate: validator,
}
}
func (e AuthEndpoint) GetName() string {
return "AuthEndpoint"
}
func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/auth")
func (e AuthEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/auth")
apiGroup.GET("/providers", e.handleExternalLoginProvidersGet())
apiGroup.GET("/session", e.handleSessionInfoGet())
apiGroup.HandleFunc("GET /providers", e.handleExternalLoginProvidersGet())
apiGroup.HandleFunc("GET /session", e.handleSessionInfoGet())
apiGroup.GET("/login/:provider/init", e.handleOauthInitiateGet())
apiGroup.GET("/login/:provider/callback", e.handleOauthCallbackGet())
apiGroup.HandleFunc("GET /login/{provider}/init", e.handleOauthInitiateGet())
apiGroup.HandleFunc("GET /login/{provider}/callback", e.handleOauthCallbackGet())
apiGroup.POST("/login", e.handleLoginPost())
apiGroup.POST("/logout", authenticator.LoggedIn(), e.handleLogoutPost())
apiGroup.HandleFunc("POST /login", e.handleLoginPost())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /logout", e.handleLogoutPost())
}
// handleExternalLoginProvidersGet returns a gorm handler function.
// handleExternalLoginProvidersGet returns a gorm Handler function.
//
// @ID auth_handleExternalLoginProvidersGet
// @Tags Authentication
@ -45,16 +71,15 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
// @Produce json
// @Success 200 {object} []model.LoginProviderInfo
// @Router /auth/providers [get]
func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
providers := e.app.Authenticator.GetExternalLoginProviders(ctx)
func (e AuthEndpoint) handleExternalLoginProvidersGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
providers := e.app.Authenticator.GetExternalLoginProviders(r.Context())
c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers))
respond.JSON(w, http.StatusOK, model.NewLoginProviderInfos(providers))
}
}
// handleSessionInfoGet returns a gorm handler function.
// handleSessionInfoGet returns a gorm Handler function.
//
// @ID auth_handleSessionInfoGet
// @Tags Authentication
@ -63,9 +88,9 @@ func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
// @Success 200 {object} []model.SessionInfo
// @Failure 500 {object} model.Error
// @Router /auth/session [get]
func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
return func(c *gin.Context) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) handleSessionInfoGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
currentSession := e.session.GetData(r.Context())
var loggedInUid *string
var firstname *string
@ -83,7 +108,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
email = &e
}
c.JSON(http.StatusOK, model.SessionInfo{
respond.JSON(w, http.StatusOK, model.SessionInfo{
LoggedIn: currentSession.LoggedIn,
IsAdmin: currentSession.IsAdmin,
UserIdentifier: loggedInUid,
@ -94,7 +119,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
}
}
// handleOauthInitiateGet returns a gorm handler function.
// handleOauthInitiateGet returns a gorm Handler function.
//
// @ID auth_handleOauthInitiateGet
// @Tags Authentication
@ -102,23 +127,24 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
// @Produce json
// @Success 200 {object} []model.LoginProviderInfo
// @Router /auth/{provider}/init [get]
func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
return func(c *gin.Context) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
currentSession := e.session.GetData(r.Context())
autoRedirect, _ := strconv.ParseBool(c.DefaultQuery("redirect", "false"))
returnTo := c.Query("return")
provider := c.Param("provider")
autoRedirect, _ := strconv.ParseBool(request.QueryDefault(r, "redirect", "false"))
returnTo := request.Query(r, "return")
provider := request.Path(r, "provider")
var returnUrl *url.URL
var returnParams string
redirectToReturn := func() {
c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams)
respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams)
}
if returnTo != "" {
if !e.isValidReturnUrl(returnTo) {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"})
return
}
if u, err := url.Parse(returnTo); err == nil {
@ -137,34 +163,34 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
returnParams = queryParams.Encode()
redirectToReturn()
} else {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "already logged in"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "already logged in"})
}
return
}
ctx := domain.SetUserInfoFromGin(c)
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider)
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(context.Background(), provider)
if err != nil {
if autoRedirect && e.isValidReturnUrl(returnTo) {
redirectToReturn()
} else {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
}
return
}
authSession := e.authenticator.Session.DefaultSessionData()
authSession := e.session.GetData(r.Context())
authSession.OauthState = state
authSession.OauthNonce = nonce
authSession.OauthProvider = provider
authSession.OauthReturnTo = returnTo
e.authenticator.Session.SetData(c, authSession)
e.session.SetData(r.Context(), authSession)
if autoRedirect {
c.Redirect(http.StatusFound, authCodeUrl)
respond.Redirect(w, r, http.StatusFound, authCodeUrl)
} else {
c.JSON(http.StatusOK, model.OauthInitiationResponse{
respond.JSON(w, http.StatusOK, model.OauthInitiationResponse{
RedirectUrl: authCodeUrl,
State: state,
})
@ -172,7 +198,7 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
}
}
// handleOauthCallbackGet returns a gorm handler function.
// handleOauthCallbackGet returns a gorm Handler function.
//
// @ID auth_handleOauthCallbackGet
// @Tags Authentication
@ -180,14 +206,14 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
// @Produce json
// @Success 200 {object} []model.LoginProviderInfo
// @Router /auth/{provider}/callback [get]
func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
return func(c *gin.Context) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) handleOauthCallbackGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
currentSession := e.session.GetData(r.Context())
var returnUrl *url.URL
var returnParams string
redirectToReturn := func() {
c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams)
respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams)
}
if currentSession.OauthReturnTo != "" {
@ -207,20 +233,20 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
returnParams = queryParams.Encode()
redirectToReturn()
} else {
c.JSON(http.StatusBadRequest, model.Error{Message: "already logged in"})
respond.JSON(w, http.StatusBadRequest, model.Error{Message: "already logged in"})
}
return
}
provider := c.Param("provider")
oauthCode := c.Query("code")
oauthState := c.Query("state")
provider := request.Path(r, "provider")
oauthCode := request.Query(r, "code")
oauthState := request.Query(r, "state")
if provider != currentSession.OauthProvider {
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
redirectToReturn()
} else {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "invalid oauth provider"})
}
return
@ -229,7 +255,8 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
redirectToReturn()
} else {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"})
}
return
}
@ -241,12 +268,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
redirectToReturn()
} else {
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: err.Error()})
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: err.Error()})
}
return
}
e.setAuthenticatedUser(c, user)
e.setAuthenticatedUser(r, user)
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
queryParams := returnUrl.Query()
@ -254,13 +282,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
returnParams = queryParams.Encode()
redirectToReturn()
} else {
c.JSON(http.StatusOK, user)
respond.JSON(w, http.StatusOK, user)
}
}
}
func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) setAuthenticatedUser(r *http.Request, user *domain.User) {
currentSession := e.session.GetData(r.Context())
currentSession.LoggedIn = true
currentSession.IsAdmin = user.IsAdmin
@ -274,10 +302,10 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
currentSession.OauthProvider = ""
currentSession.OauthReturnTo = ""
e.authenticator.Session.SetData(c, currentSession)
e.session.SetData(r.Context(), currentSession)
}
// handleLoginPost returns a gorm handler function.
// handleLoginPost returns a gorm Handler function.
//
// @ID auth_handleLoginPost
// @Tags Authentication
@ -285,11 +313,11 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
// @Produce json
// @Success 200 {object} []model.LoginProviderInfo
// @Router /auth/login [post]
func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
return func(c *gin.Context) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) handleLoginPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
currentSession := e.session.GetData(r.Context())
if currentSession.LoggedIn {
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"})
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"})
return
}
@ -298,25 +326,29 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
Password string `json:"password" binding:"required,min=4"`
}
if err := c.ShouldBindJSON(&loginData); err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &loginData); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validate.Struct(loginData); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
ctx := domain.SetUserInfoFromGin(c)
user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password)
user, err := e.app.Authenticator.PlainLogin(context.Background(), loginData.Username, loginData.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
return
}
e.setAuthenticatedUser(c, user)
e.setAuthenticatedUser(r, user)
c.JSON(http.StatusOK, user)
respond.JSON(w, http.StatusOK, user)
}
}
// handleLogoutPost returns a gorm handler function.
// handleLogoutPost returns a gorm Handler function.
//
// @ID auth_handleLogoutGet
// @Tags Authentication
@ -324,22 +356,22 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
// @Produce json
// @Success 200 {object} []model.LoginProviderInfo
// @Router /auth/logout [get]
func (e authEndpoint) handleLogoutPost() gin.HandlerFunc {
return func(c *gin.Context) {
currentSession := e.authenticator.Session.GetData(c)
func (e AuthEndpoint) handleLogoutPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
currentSession := e.session.GetData(r.Context())
if !currentSession.LoggedIn { // Not logged in
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"})
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"})
return
}
e.authenticator.Session.DestroyData(c)
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"})
e.session.DestroyData(r.Context())
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"})
}
}
// isValidReturnUrl checks if the given return URL matches the configured external URL of the application.
func (e authEndpoint) isValidReturnUrl(returnUrl string) bool {
func (e AuthEndpoint) isValidReturnUrl(returnUrl string) bool {
if !strings.HasPrefix(returnUrl, e.app.Config.Web.ExternalUrl) {
return false
}

View File

@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"embed"
"fmt"
"html/template"
@ -9,57 +8,61 @@ import (
"net/http"
"net/url"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/config"
)
//go:embed frontend_config.js.gotpl
var frontendJs embed.FS
type configEndpoint struct {
app *app.App
authenticator *authenticationHandler
type ConfigEndpoint struct {
cfg *config.Config
authenticator Authenticator
tpl *template.Template
tpl *respond.TemplateRenderer
}
func newConfigEndpoint(app *app.App, authenticator *authenticationHandler) configEndpoint {
ep := configEndpoint{
app: app,
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
ep := ConfigEndpoint{
cfg: cfg,
authenticator: authenticator,
tpl: template.Must(template.ParseFS(frontendJs, "frontend_config.js.gotpl")),
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
"frontend_config.js.gotpl"))),
}
return ep
}
func (e configEndpoint) GetName() string {
func (e ConfigEndpoint) GetName() string {
return "ConfigEndpoint"
}
func (e configEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
apiGroup := g.Group("/config")
func (e ConfigEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/config")
apiGroup.GET("/frontend.js", e.handleConfigJsGet())
apiGroup.GET("/settings", e.authenticator.LoggedIn(), e.handleSettingsGet())
apiGroup.HandleFunc("GET /frontend.js", e.handleConfigJsGet())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("GET /settings", e.handleSettingsGet())
}
// handleConfigJsGet returns a gorm handler function.
// handleConfigJsGet returns a gorm Handler function.
//
// @ID config_handleConfigJsGet
// @Tags Configuration
// @Summary Get the dynamic frontend configuration javascript.
// @Produce text/javascript
// @Success 200 string javascript "The JavaScript contents"
// @Failure 500
// @Router /config/frontend.js [get]
func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
return func(c *gin.Context) {
backendUrl := fmt.Sprintf("%s/api/v0", e.app.Config.Web.ExternalUrl)
if c.GetHeader("x-wg-dev") != "" {
referer := c.Request.Header.Get("Referer")
func (e ConfigEndpoint) handleConfigJsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
backendUrl := fmt.Sprintf("%s/api/v0", e.cfg.Web.ExternalUrl)
if request.Header(r, "x-wg-dev") != "" {
referer := request.Header(r, "Referer")
host := "localhost"
port := "5000"
parsedReferer, err := url.Parse(referer)
@ -69,23 +72,17 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
backendUrl = fmt.Sprintf("http://%s:%s/api/v0", host,
port) // override if request comes from frontend started with npm run dev
}
buf := &bytes.Buffer{}
err := e.tpl.ExecuteTemplate(buf, "frontend_config.js.gotpl", gin.H{
e.tpl.Render(w, http.StatusOK, "frontend_config.js.gotpl", "text/javascript", map[string]any{
"BackendUrl": backendUrl,
"Version": internal.Version,
"SiteTitle": e.app.Config.Web.SiteTitle,
"SiteCompanyName": e.app.Config.Web.SiteCompanyName,
"SiteTitle": e.cfg.Web.SiteTitle,
"SiteCompanyName": e.cfg.Web.SiteCompanyName,
})
if err != nil {
c.Status(http.StatusInternalServerError)
return
}
c.Data(http.StatusOK, "application/javascript", buf.Bytes())
}
}
// handleSettingsGet returns a gorm handler function.
// handleSettingsGet returns a gorm Handler function.
//
// @ID config_handleSettingsGet
// @Tags Configuration
@ -94,13 +91,13 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
// @Success 200 {object} model.Settings
// @Success 200 string javascript "The JavaScript contents"
// @Router /config/settings [get]
func (e configEndpoint) handleSettingsGet() gin.HandlerFunc {
return func(c *gin.Context) {
c.JSON(http.StatusOK, model.Settings{
MailLinkOnly: e.app.Config.Mail.LinkOnly,
PersistentConfigSupported: e.app.Config.Advanced.ConfigStoragePath != "",
SelfProvisioning: e.app.Config.Core.SelfProvisioningAllowed,
ApiAdminOnly: e.app.Config.Advanced.ApiAdminOnly,
func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
respond.JSON(w, http.StatusOK, model.Settings{
MailLinkOnly: e.cfg.Mail.LinkOnly,
PersistentConfigSupported: e.cfg.Advanced.ConfigStoragePath != "",
SelfProvisioning: e.cfg.Core.SelfProvisioningAllowed,
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
})
}
}

View File

@ -4,39 +4,51 @@ import (
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type interfaceEndpoint struct {
type InterfaceEndpoint struct {
app *app.App
authenticator *authenticationHandler
authenticator Authenticator
validator Validator
}
func (e interfaceEndpoint) GetName() string {
func NewInterfaceEndpoint(app *app.App, authenticator Authenticator, validator Validator) InterfaceEndpoint {
return InterfaceEndpoint{
app: app,
authenticator: authenticator,
validator: validator,
}
}
func (e InterfaceEndpoint) GetName() string {
return "InterfaceEndpoint"
}
func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin))
func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/interface")
apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin))
apiGroup.GET("/prepare", e.handlePrepareGet())
apiGroup.GET("/all", e.handleAllGet())
apiGroup.GET("/get/:id", e.handleSingleGet())
apiGroup.PUT("/:id", e.handleUpdatePut())
apiGroup.DELETE("/:id", e.handleDelete())
apiGroup.POST("/new", e.handleCreatePost())
apiGroup.GET("/config/:id", e.handleConfigGet())
apiGroup.POST("/:id/save-config", e.handleSaveConfigPost())
apiGroup.POST("/:id/apply-peer-defaults", e.handleApplyPeerDefaultsPost())
apiGroup.HandleFunc("GET /prepare", e.handlePrepareGet())
apiGroup.HandleFunc("GET /all", e.handleAllGet())
apiGroup.HandleFunc("GET /get/{id}", e.handleSingleGet())
apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut())
apiGroup.HandleFunc("DELETE /{id}", e.handleDelete())
apiGroup.HandleFunc("POST /new", e.handleCreatePost())
apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet())
apiGroup.HandleFunc("POST /{id}/save-config", e.handleSaveConfigPost())
apiGroup.HandleFunc("POST /{id}/apply-peer-defaults", e.handleApplyPeerDefaultsPost())
apiGroup.GET("/peers/:id", e.handlePeersGet())
apiGroup.HandleFunc("GET /peers/{id}", e.handlePeersGet())
}
// handlePrepareGet returns a gorm handler function.
// handlePrepareGet returns a gorm Handler function.
//
// @ID interfaces_handlePrepareGet
// @Tags Interface
@ -45,22 +57,21 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationH
// @Success 200 {object} model.Interface
// @Failure 500 {object} model.Error
// @Router /interface/prepare [get]
func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
in, err := e.app.PrepareInterface(ctx)
func (e InterfaceEndpoint) handlePrepareGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
in, err := e.app.PrepareInterface(r.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewInterface(in, nil))
respond.JSON(w, http.StatusOK, model.NewInterface(in, nil))
}
}
// handleAllGet returns a gorm handler function.
// handleAllGet returns a gorm Handler function.
//
// @ID interfaces_handleAllGet
// @Tags Interface
@ -69,22 +80,21 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
// @Success 200 {object} []model.Interface
// @Failure 500 {object} model.Error
// @Router /interface/all [get]
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx)
func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(r.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers))
respond.JSON(w, http.StatusOK, model.NewInterfaces(interfaces, peers))
}
}
// handleSingleGet returns a gorm handler function.
// handleSingleGet returns a gorm Handler function.
//
// @ID interfaces_handleSingleGet
// @Tags Interface
@ -94,30 +104,29 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/get/{id} [get]
func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleSingleGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
respond.JSON(w, http.StatusBadRequest, model.Error{
Code: http.StatusInternalServerError, Message: "missing id parameter",
})
return
}
iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
iface, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewInterface(iface, peers))
respond.JSON(w, http.StatusOK, model.NewInterface(iface, peers))
}
}
// handleConfigGet returns a gorm handler function.
// handleConfigGet returns a gorm Handler function.
//
// @ID interfaces_handleConfigGet
// @Tags Interface
@ -127,20 +136,19 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/config/{id} [get]
func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleConfigGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
respond.JSON(w, http.StatusBadRequest, model.Error{
Code: http.StatusInternalServerError, Message: "missing id parameter",
})
return
}
config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
config, err := e.app.GetInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
@ -148,17 +156,17 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
configString, err := io.ReadAll(config)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, string(configString))
respond.JSON(w, http.StatusOK, string(configString))
}
}
// handleUpdatePut returns a gorm handler function.
// handleUpdatePut returns a gorm Handler function.
//
// @ID interfaces_handleUpdatePut
// @Tags Interface
@ -170,41 +178,44 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/{id} [put]
func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
var in model.Interface
err := c.BindJSON(&in)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if id != in.Identifier {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
return
}
updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in))
updatedInterface, peers, err := e.app.UpdateInterface(r.Context(), model.NewDomainInterface(&in))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers))
respond.JSON(w, http.StatusOK, model.NewInterface(updatedInterface, peers))
}
}
// handleCreatePost returns a gorm handler function.
// handleCreatePost returns a gorm Handler function.
//
// @ID interfaces_handleCreatePost
// @Tags Interface
@ -215,30 +226,31 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/new [post]
func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var in model.Interface
err := c.BindJSON(&in)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newInterface, err := e.app.CreateInterface(ctx, model.NewDomainInterface(&in))
newInterface, err := e.app.CreateInterface(r.Context(), model.NewDomainInterface(&in))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewInterface(newInterface, nil))
respond.JSON(w, http.StatusOK, model.NewInterface(newInterface, nil))
}
}
// handlePeersGet returns a gorm handler function.
// handlePeersGet returns a gorm Handler function.
//
// @ID interfaces_handlePeersGet
// @Tags Interface
@ -247,31 +259,29 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
// @Success 200 {object} []model.Peer
// @Failure 500 {object} model.Error
// @Router /interface/peers/{id} [get]
func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handlePeersGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
respond.JSON(w, http.StatusBadRequest, model.Error{
Code: http.StatusInternalServerError, Message: "missing id parameter",
})
return
}
_, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
_, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, model.NewPeers(peers))
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
}
}
// handleDelete returns a gorm handler function.
// handleDelete returns a gorm Handler function.
//
// @ID interfaces_handleDelete
// @Tags Interface
@ -282,29 +292,28 @@ func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/{id} [delete]
func (e interfaceEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
err := e.app.DeleteInterface(ctx, domain.InterfaceIdentifier(id))
err := e.app.DeleteInterface(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}
// handleSaveConfigPost returns a gorm handler function.
// handleSaveConfigPost returns a gorm Handler function.
//
// @ID interfaces_handleSaveConfigPost
// @Tags Interface
@ -315,29 +324,28 @@ func (e interfaceEndpoint) handleDelete() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/{id}/save-config [post]
func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleSaveConfigPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
err := e.app.PersistInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
err := e.app.PersistInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}
// handleApplyPeerDefaultsPost returns a gorm handler function.
// handleApplyPeerDefaultsPost returns a gorm Handler function.
//
// @ID interfaces_handleApplyPeerDefaultsPost
// @Tags Interface
@ -349,36 +357,38 @@ func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /interface/{id}/apply-peer-defaults [post]
func (e interfaceEndpoint) handleApplyPeerDefaultsPost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e InterfaceEndpoint) handleApplyPeerDefaultsPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
var in model.Interface
err := c.BindJSON(&in)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(in); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if id != in.Identifier {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
return
}
err = e.app.ApplyPeerDefaults(ctx, model.NewDomainInterface(&in))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
if err := e.app.ApplyPeerDefaults(r.Context(), model.NewDomainInterface(&in)); err != nil {
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}

View File

@ -4,39 +4,52 @@ import (
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type peerEndpoint struct {
type PeerEndpoint struct {
app *app.App
authenticator *authenticationHandler
authenticator Authenticator
validator Validator
}
func (e peerEndpoint) GetName() string {
func NewPeerEndpoint(app *app.App, authenticator Authenticator, validator Validator) PeerEndpoint {
return PeerEndpoint{
app: app,
authenticator: authenticator,
validator: validator,
}
}
func (e PeerEndpoint) GetName() string {
return "PeerEndpoint"
}
func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
apiGroup := g.Group("/peer", e.authenticator.LoggedIn())
func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/peer")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet())
apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(), e.handlePrepareGet())
apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(), e.handleCreatePost())
apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost())
apiGroup.GET("/config-qr/:id", e.handleQrCodeGet())
apiGroup.POST("/config-mail", e.handleEmailPost())
apiGroup.GET("/config/:id", e.handleConfigGet())
apiGroup.GET("/:id", e.handleSingleGet())
apiGroup.PUT("/:id", e.handleUpdatePut())
apiGroup.DELETE("/:id", e.handleDelete())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/all", e.handleAllGet())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/stats", e.handleStatsGet())
apiGroup.HandleFunc("GET /iface/{iface}/prepare", e.handlePrepareGet())
apiGroup.HandleFunc("POST /iface/{iface}/new", e.handleCreatePost())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /iface/{iface}/multiplenew",
e.handleCreateMultiplePost())
apiGroup.HandleFunc("GET /config-qr/{id}", e.handleQrCodeGet())
apiGroup.HandleFunc("POST /config-mail", e.handleEmailPost())
apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet())
apiGroup.HandleFunc("GET /{id}", e.handleSingleGet())
apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut())
apiGroup.HandleFunc("DELETE /{id}", e.handleDelete())
}
// handleAllGet returns a gorm handler function.
// handleAllGet returns a gorm Handler function.
//
// @ID peers_handleAllGet
// @Tags Peer
@ -47,28 +60,27 @@ func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/iface/{iface}/all [get]
func (e peerEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaceId := Base64UrlDecode(c.Param("iface"))
func (e PeerEndpoint) handleAllGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
if interfaceId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
return
}
_, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(interfaceId))
_, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(interfaceId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeers(peers))
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
}
}
// handleSingleGet returns a gorm handler function.
// handleSingleGet returns a gorm Handler function.
//
// @ID peers_handleSingleGet
// @Tags Peer
@ -79,28 +91,27 @@ func (e peerEndpoint) handleAllGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/{id} [get]
func (e peerEndpoint) handleSingleGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
peerId := Base64UrlDecode(c.Param("id"))
func (e PeerEndpoint) handleSingleGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
peerId := Base64UrlDecode(request.Path(r, "id"))
if peerId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
return
}
peer, err := e.app.GetPeer(ctx, domain.PeerIdentifier(peerId))
peer, err := e.app.GetPeer(r.Context(), domain.PeerIdentifier(peerId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeer(peer))
respond.JSON(w, http.StatusOK, model.NewPeer(peer))
}
}
// handlePrepareGet returns a gorm handler function.
// handlePrepareGet returns a gorm Handler function.
//
// @ID peers_handlePrepareGet
// @Tags Peer
@ -111,28 +122,27 @@ func (e peerEndpoint) handleSingleGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/iface/{iface}/prepare [get]
func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaceId := Base64UrlDecode(c.Param("iface"))
func (e PeerEndpoint) handlePrepareGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
if interfaceId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
return
}
peer, err := e.app.PreparePeer(ctx, domain.InterfaceIdentifier(interfaceId))
peer, err := e.app.PreparePeer(r.Context(), domain.InterfaceIdentifier(interfaceId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeer(peer))
respond.JSON(w, http.StatusOK, model.NewPeer(peer))
}
}
// handleCreatePost returns a gorm handler function.
// handleCreatePost returns a gorm Handler function.
//
// @ID peers_handleCreatePost
// @Tags Peer
@ -144,40 +154,43 @@ func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/iface/{iface}/new [post]
func (e peerEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaceId := Base64UrlDecode(c.Param("iface"))
func (e PeerEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
if interfaceId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
return
}
var p model.Peer
err := c.BindJSON(&p)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &p); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(p); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if p.InterfaceIdentifier != interfaceId {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
return
}
newPeer, err := e.app.CreatePeer(ctx, model.NewDomainPeer(&p))
newPeer, err := e.app.CreatePeer(r.Context(), model.NewDomainPeer(&p))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeer(newPeer))
respond.JSON(w, http.StatusOK, model.NewPeer(newPeer))
}
}
// handleCreateMultiplePost returns a gorm handler function.
// handleCreateMultiplePost returns a gorm Handler function.
//
// @ID peers_handleCreateMultiplePost
// @Tags Peer
@ -189,36 +202,38 @@ func (e peerEndpoint) handleCreatePost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/iface/{iface}/multiplenew [post]
func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaceId := Base64UrlDecode(c.Param("iface"))
func (e PeerEndpoint) handleCreateMultiplePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
if interfaceId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
return
}
var req model.MultiPeerRequest
err := c.BindJSON(&req)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &req); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(req); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newPeers, err := e.app.CreateMultiplePeers(ctx, domain.InterfaceIdentifier(interfaceId),
newPeers, err := e.app.CreateMultiplePeers(r.Context(), domain.InterfaceIdentifier(interfaceId),
model.NewDomainPeerCreationRequest(&req))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeers(newPeers))
respond.JSON(w, http.StatusOK, model.NewPeers(newPeers))
}
}
// handleUpdatePut returns a gorm handler function.
// handleUpdatePut returns a gorm Handler function.
//
// @ID peers_handleUpdatePut
// @Tags Peer
@ -230,40 +245,43 @@ func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/{id} [put]
func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
peerId := Base64UrlDecode(c.Param("id"))
func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
peerId := Base64UrlDecode(request.Path(r, "id"))
if peerId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
return
}
var p model.Peer
err := c.BindJSON(&p)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &p); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(p); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if p.Identifier != peerId {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"})
return
}
updatedPeer, err := e.app.UpdatePeer(ctx, model.NewDomainPeer(&p))
updatedPeer, err := e.app.UpdatePeer(r.Context(), model.NewDomainPeer(&p))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeer(updatedPeer))
respond.JSON(w, http.StatusOK, model.NewPeer(updatedPeer))
}
}
// handleDelete returns a gorm handler function.
// handleDelete returns a gorm Handler function.
//
// @ID peers_handleDelete
// @Tags Peer
@ -274,28 +292,26 @@ func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/{id} [delete]
func (e peerEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e PeerEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
err := e.app.DeletePeer(ctx, domain.PeerIdentifier(id))
err := e.app.DeletePeer(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}
// handleConfigGet returns a gorm handler function.
// handleConfigGet returns a gorm Handler function.
//
// @ID peers_handleConfigGet
// @Tags Peer
@ -306,21 +322,19 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/config/{id} [get]
func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e PeerEndpoint) handleConfigGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
respond.JSON(w, http.StatusBadRequest, model.Error{
Code: http.StatusInternalServerError, Message: "missing id parameter",
})
return
}
config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id))
config, err := e.app.GetPeerConfig(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
@ -328,17 +342,17 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
configString, err := io.ReadAll(config)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.JSON(http.StatusOK, string(configString))
respond.JSON(w, http.StatusOK, string(configString))
}
}
// handleQrCodeGet returns a gorm handler function.
// handleQrCodeGet returns a gorm Handler function.
//
// @ID peers_handleQrCodeGet
// @Tags Peer
@ -350,20 +364,19 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/config-qr/{id} [get]
func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e PeerEndpoint) handleQrCodeGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
respond.JSON(w, http.StatusBadRequest, model.Error{
Code: http.StatusInternalServerError, Message: "missing id parameter",
})
return
}
config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id))
config, err := e.app.GetPeerConfigQrCode(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
@ -371,17 +384,17 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
configData, err := io.ReadAll(config)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
})
return
}
c.Data(http.StatusOK, "image/png", configData)
respond.Data(w, http.StatusOK, "image/png", configData)
}
}
// handleEmailPost returns a gorm handler function.
// handleEmailPost returns a gorm Handler function.
//
// @ID peers_handleEmailPost
// @Tags Peer
@ -392,38 +405,39 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/config-mail [post]
func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
return func(c *gin.Context) {
func (e PeerEndpoint) handleEmailPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req model.PeerMailRequest
err := c.BindJSON(&req)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &req); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(req); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if len(req.Identifiers) == 0 {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"})
return
}
ctx := domain.SetUserInfoFromGin(c)
peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
for i := range req.Identifiers {
peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
}
err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...)
if err != nil {
c.JSON(http.StatusInternalServerError,
if err := e.app.SendPeerEmail(r.Context(), req.LinkOnly, peerIds...); err != nil {
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}
// handleStatsGet returns a gorm handler function.
// handleStatsGet returns a gorm Handler function.
//
// @ID peers_handleStatsGet
// @Tags Peer
@ -434,23 +448,22 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /peer/iface/{iface}/stats [get]
func (e peerEndpoint) handleStatsGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
interfaceId := Base64UrlDecode(c.Param("iface"))
func (e PeerEndpoint) handleStatsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
if interfaceId == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
return
}
stats, err := e.app.GetPeerStats(ctx, domain.InterfaceIdentifier(interfaceId))
stats, err := e.app.GetPeerStats(r.Context(), domain.InterfaceIdentifier(interfaceId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
}
}

View File

@ -5,20 +5,29 @@ import (
"os"
"time"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
)
type testEndpoint struct{}
type TestEndpoint struct {
authenticator Authenticator
}
func (e testEndpoint) GetName() string {
func NewTestEndpoint(authenticator Authenticator) TestEndpoint {
return TestEndpoint{
authenticator: authenticator,
}
}
func (e TestEndpoint) GetName() string {
return "TestEndpoint"
}
func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
g.GET("/now", e.handleCurrentTimeGet())
g.GET("/hostname", e.handleHostnameGet())
func (e TestEndpoint) RegisterRoutes(g *routegroup.Bundle) {
g.HandleFunc("GET /now", e.handleCurrentTimeGet())
g.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /hostname", e.handleHostnameGet())
}
// handleCurrentTimeGet represents the GET endpoint that responds the current time
@ -31,15 +40,15 @@ func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
// @Success 200 {object} string
// @Failure 500 {object} model.Error
// @Router /now [get]
func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc {
return func(c *gin.Context) {
func (e TestEndpoint) handleCurrentTimeGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if time.Now().Second() == 0 {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError,
Message: "invalid time",
})
}
c.JSON(http.StatusOK, time.Now().String())
respond.JSON(w, http.StatusOK, time.Now().String())
}
}
@ -53,15 +62,15 @@ func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc {
// @Success 200 {object} string
// @Failure 500 {object} model.Error
// @Router /hostname [get]
func (e testEndpoint) handleHostnameGet() gin.HandlerFunc {
return func(c *gin.Context) {
func (e TestEndpoint) handleHostnameGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
hostname, err := os.Hostname()
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError,
Message: err.Error(),
})
}
c.JSON(http.StatusOK, hostname)
respond.JSON(w, http.StatusOK, hostname)
}
}

View File

@ -3,38 +3,50 @@ package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type userEndpoint struct {
type UserEndpoint struct {
app *app.App
authenticator *authenticationHandler
authenticator Authenticator
validator Validator
}
func (e userEndpoint) GetName() string {
func NewUserEndpoint(app *app.App, authenticator Authenticator, validator Validator) UserEndpoint {
return UserEndpoint{
app: app,
authenticator: authenticator,
validator: validator,
}
}
func (e UserEndpoint) GetName() string {
return "UserEndpoint"
}
func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
apiGroup := g.Group("/user", e.authenticator.LoggedIn())
func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/user")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet())
apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut())
apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete())
apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet())
apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet())
apiGroup.GET("/:id/interfaces", e.authenticator.UserIdMatch("id"), e.handleInterfacesGet())
apiGroup.POST("/:id/api/enable", e.authenticator.UserIdMatch("id"), e.handleApiEnablePost())
apiGroup.POST("/:id/api/disable", e.authenticator.UserIdMatch("id"), e.handleApiDisablePost())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}", e.handleSingleGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("PUT /{id}", e.handleUpdatePut())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("DELETE /{id}", e.handleDelete())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/peers", e.handlePeersGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/stats", e.handleStatsGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/interfaces", e.handleInterfacesGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/enable", e.handleApiEnablePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/disable", e.handleApiDisablePost())
}
// handleAllGet returns a gorm handler function.
// handleAllGet returns a gorm Handler function.
//
// @ID users_handleAllGet
// @Tags Users
@ -43,22 +55,20 @@ func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
// @Success 200 {object} []model.User
// @Failure 500 {object} model.Error
// @Router /user/all [get]
func (e userEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
users, err := e.app.GetAllUsers(ctx)
func (e UserEndpoint) handleAllGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
users, err := e.app.GetAllUsers(r.Context())
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUsers(users))
respond.JSON(w, http.StatusOK, model.NewUsers(users))
}
}
// handleSingleGet returns a gorm handler function.
// handleSingleGet returns a gorm Handler function.
//
// @ID users_handleSingleGet
// @Tags Users
@ -68,28 +78,26 @@ func (e userEndpoint) handleAllGet() gin.HandlerFunc {
// @Success 200 {object} model.User
// @Failure 500 {object} model.Error
// @Router /user/{id} [get]
func (e userEndpoint) handleSingleGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleSingleGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
user, err := e.app.GetUser(ctx, domain.UserIdentifier(id))
user, err := e.app.GetUser(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUser(user, true))
respond.JSON(w, http.StatusOK, model.NewUser(user, true))
}
}
// handleUpdatePut returns a gorm handler function.
// handleUpdatePut returns a gorm Handler function.
//
// @ID users_handleUpdatePut
// @Tags Users
@ -101,40 +109,42 @@ func (e userEndpoint) handleSingleGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id} [put]
func (e userEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
var user model.User
err := c.BindJSON(&user)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &user); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(user); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if id != user.Identifier {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"})
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"})
return
}
updateUser, err := e.app.UpdateUser(ctx, model.NewDomainUser(&user))
updateUser, err := e.app.UpdateUser(r.Context(), model.NewDomainUser(&user))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUser(updateUser, false))
respond.JSON(w, http.StatusOK, model.NewUser(updateUser, false))
}
}
// handleCreatePost returns a gorm handler function.
// handleCreatePost returns a gorm Handler function.
//
// @ID users_handleCreatePost
// @Tags Users
@ -145,29 +155,30 @@ func (e userEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/new [post]
func (e userEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e UserEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var user model.User
err := c.BindJSON(&user)
if err != nil {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &user); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(user); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newUser, err := e.app.CreateUser(ctx, model.NewDomainUser(&user))
newUser, err := e.app.CreateUser(r.Context(), model.NewDomainUser(&user))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUser(newUser, false))
respond.JSON(w, http.StatusOK, model.NewUser(newUser, false))
}
}
// handlePeersGet returns a gorm handler function.
// handlePeersGet returns a gorm Handler function.
//
// @ID users_handlePeersGet
// @Tags Users
@ -178,29 +189,27 @@ func (e userEndpoint) handleCreatePost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/peers [get]
func (e userEndpoint) handlePeersGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
userId := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handlePeersGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
peers, err := e.app.GetUserPeers(ctx, domain.UserIdentifier(userId))
peers, err := e.app.GetUserPeers(r.Context(), domain.UserIdentifier(userId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeers(peers))
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
}
}
// handleStatsGet returns a gorm handler function.
// handleStatsGet returns a gorm Handler function.
//
// @ID users_handleStatsGet
// @Tags Users
@ -211,29 +220,27 @@ func (e userEndpoint) handlePeersGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/stats [get]
func (e userEndpoint) handleStatsGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
userId := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleStatsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
stats, err := e.app.GetUserPeerStats(ctx, domain.UserIdentifier(userId))
stats, err := e.app.GetUserPeerStats(r.Context(), domain.UserIdentifier(userId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
}
}
// handleInterfacesGet returns a gorm handler function.
// handleInterfacesGet returns a gorm Handler function.
//
// @ID users_handleInterfacesGet
// @Tags Users
@ -244,29 +251,27 @@ func (e userEndpoint) handleStatsGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/interfaces [get]
func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
userId := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleInterfacesGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
peers, err := e.app.GetUserInterfaces(ctx, domain.UserIdentifier(userId))
peers, err := e.app.GetUserInterfaces(r.Context(), domain.UserIdentifier(userId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewInterfaces(peers, nil))
respond.JSON(w, http.StatusOK, model.NewInterfaces(peers, nil))
}
}
// handleDelete returns a gorm handler function.
// handleDelete returns a gorm Handler function.
//
// @ID users_handleDelete
// @Tags Users
@ -277,28 +282,26 @@ func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id} [delete]
func (e userEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := Base64UrlDecode(request.Path(r, "id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
err := e.app.DeleteUser(ctx, domain.UserIdentifier(id))
err := e.app.DeleteUser(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}
// handleApiEnablePost returns a gorm handler function.
// handleApiEnablePost returns a gorm Handler function.
//
// @ID users_handleApiEnablePost
// @Tags Users
@ -308,29 +311,27 @@ func (e userEndpoint) handleDelete() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/api/enable [post]
func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
userId := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleApiEnablePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
user, err := e.app.ActivateApi(ctx, domain.UserIdentifier(userId))
user, err := e.app.ActivateApi(r.Context(), domain.UserIdentifier(userId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUser(user, true))
respond.JSON(w, http.StatusOK, model.NewUser(user, true))
}
}
// handleApiDisablePost returns a gorm handler function.
// handleApiDisablePost returns a gorm Handler function.
//
// @ID users_handleApiDisablePost
// @Tags Users
@ -340,24 +341,22 @@ func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc {
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/api/disable [post]
func (e userEndpoint) handleApiDisablePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
userId := Base64UrlDecode(c.Param("id"))
func (e UserEndpoint) handleApiDisablePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
c.JSON(http.StatusBadRequest,
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
user, err := e.app.DeactivateApi(ctx, domain.UserIdentifier(userId))
user, err := e.app.DeactivateApi(r.Context(), domain.UserIdentifier(userId))
if err != nil {
c.JSON(http.StatusInternalServerError,
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
c.JSON(http.StatusOK, model.NewUser(user, false))
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}

View File

@ -1,111 +0,0 @@
package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type Scope string
const (
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
)
type authenticationHandler struct {
app *app.App
Session SessionStore
}
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
return func(c *gin.Context) {
session := h.Session.GetData(c)
if !session.LoggedIn {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "not logged in"})
return
}
if !UserHasScopes(session, scopes...) {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
// Check if logged-in user is still valid
if !h.app.Authenticator.IsUserValid(c.Request.Context(), domain.UserIdentifier(session.UserIdentifier)) {
h.Session.DestroyData(c)
c.Abort()
c.JSON(http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"})
return
}
c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{
Id: domain.UserIdentifier(session.UserIdentifier),
IsAdmin: session.IsAdmin,
})
// Continue down the chain to handler etc
c.Next()
}
}
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc {
return func(c *gin.Context) {
session := h.Session.GetData(c)
if session.IsAdmin {
c.Next() // Admins can do everything
return
}
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter)))
if sessionUserId != requestUserId {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
// Continue down the chain to handler etc
c.Next()
}
}
func UserHasScopes(session SessionData, scopes ...Scope) bool {
// No scopes give, so the check should succeed
if len(scopes) == 0 {
return true
}
// check if user has admin scope
if session.IsAdmin {
return true
}
// Check if admin scope is required
for _, scope := range scopes {
if scope == ScopeAdmin {
return false
}
}
// For all other scopes, a logged-in user is sufficient (for now)
if session.LoggedIn {
return true
}
return false
}

View File

@ -1,92 +0,0 @@
package handlers
import (
"encoding/gob"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
func init() {
gob.Register(SessionData{})
}
type SessionData struct {
LoggedIn bool
IsAdmin bool
UserIdentifier string
Firstname string
Lastname string
Email string
OauthState string
OauthNonce string
OauthProvider string
OauthReturnTo string
}
type SessionStore interface {
DefaultSessionData() SessionData
GetData(c *gin.Context) SessionData
SetData(c *gin.Context, data SessionData)
DestroyData(c *gin.Context)
}
type GinSessionStore struct {
sessionIdentifier string
}
func (g GinSessionStore) GetData(c *gin.Context) SessionData {
session := sessions.Default(c)
rawSessionData := session.Get(g.sessionIdentifier)
var sessionData SessionData
if rawSessionData != nil {
sessionData = rawSessionData.(SessionData)
} else {
// init a new default session
sessionData = g.DefaultSessionData()
session.Set(g.sessionIdentifier, sessionData)
if err := session.Save(); err != nil {
panic(fmt.Sprintf("failed to store session: %v", err))
}
}
return sessionData
}
func (g GinSessionStore) DefaultSessionData() SessionData {
return SessionData{
LoggedIn: false,
IsAdmin: false,
UserIdentifier: "",
Firstname: "",
Lastname: "",
Email: "",
OauthState: "",
OauthNonce: "",
OauthProvider: "",
OauthReturnTo: "",
}
}
func (g GinSessionStore) SetData(c *gin.Context, data SessionData) {
session := sessions.Default(c)
session.Set(g.sessionIdentifier, data)
if err := session.Save(); err != nil {
panic(fmt.Sprintf("failed to store session: %v", err))
}
}
func (g GinSessionStore) DestroyData(c *gin.Context) {
session := sessions.Default(c)
session.Delete(g.sessionIdentifier)
if err := session.Save(); err != nil {
panic(fmt.Sprintf("failed to store session: %v", err))
}
}

View File

@ -0,0 +1,126 @@
package handlers
import (
"context"
"net/http"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type Scope string
const (
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
)
type UserAuthenticator interface {
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
}
type AuthenticationHandler struct {
authenticator UserAuthenticator
session Session
}
func NewAuthenticationHandler(authenticator UserAuthenticator, session Session) AuthenticationHandler {
return AuthenticationHandler{
authenticator: authenticator,
session: session,
}
}
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := h.session.GetData(r.Context())
if !session.LoggedIn {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "not logged in"})
return
}
if !UserHasScopes(session, scopes...) {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusForbidden,
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
// Check if logged-in user is still valid
if !h.authenticator.IsUserValid(r.Context(), domain.UserIdentifier(session.UserIdentifier)) {
h.session.DestroyData(r.Context())
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"})
return
}
ctx := context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{
Id: domain.UserIdentifier(session.UserIdentifier),
IsAdmin: session.IsAdmin,
})
r = r.WithContext(ctx)
// Continue down the chain to Handler etc
next.ServeHTTP(w, r)
})
}
}
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
func (h AuthenticationHandler) UserIdMatch(idParameter string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := h.session.GetData(r.Context())
if session.IsAdmin {
next.ServeHTTP(w, r) // Admins can do everything
return
}
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
requestUserId := domain.UserIdentifier(Base64UrlDecode(request.Path(r, idParameter)))
if sessionUserId != requestUserId {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusForbidden,
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
// Continue down the chain to Handler etc
next.ServeHTTP(w, r)
})
}
}
func UserHasScopes(session SessionData, scopes ...Scope) bool {
// No scopes give, so the check should succeed
if len(scopes) == 0 {
return true
}
// check if user has admin scope
if session.IsAdmin {
return true
}
// Check if admin scope is required
for _, scope := range scopes {
if scope == ScopeAdmin {
return false
}
}
// For all other scopes, a logged-in user is sufficient (for now)
if session.LoggedIn {
return true
}
return false
}

View File

@ -0,0 +1,88 @@
package handlers
import (
"context"
"encoding/gob"
"net/http"
"strings"
"time"
"github.com/alexedwards/scs/v2"
"github.com/h44z/wg-portal/internal/config"
)
func init() {
gob.Register(SessionData{})
}
type SessionData struct {
LoggedIn bool
IsAdmin bool
UserIdentifier string
Firstname string
Lastname string
Email string
OauthState string
OauthNonce string
OauthProvider string
OauthReturnTo string
CsrfToken string
}
const sessionApiV0Key = "session_api_v0"
type SessionWrapper struct {
*scs.SessionManager
}
func NewSessionWrapper(cfg *config.Config) *SessionWrapper {
sessionManager := scs.New()
sessionManager.Lifetime = 24 * time.Hour
sessionManager.IdleTimeout = 1 * time.Hour
sessionManager.Cookie.Name = cfg.Web.SessionIdentifier
sessionManager.Cookie.Secure = strings.HasPrefix(cfg.Web.ExternalUrl, "https")
sessionManager.Cookie.HttpOnly = true
sessionManager.Cookie.SameSite = http.SameSiteLaxMode
sessionManager.Cookie.Path = "/"
sessionManager.Cookie.Persist = false
wrappedSessionManager := &SessionWrapper{sessionManager}
return wrappedSessionManager
}
func (s *SessionWrapper) SetData(ctx context.Context, value SessionData) {
s.SessionManager.Put(ctx, sessionApiV0Key, value)
}
func (s *SessionWrapper) GetData(ctx context.Context) SessionData {
sessionData, ok := s.SessionManager.Get(ctx, sessionApiV0Key).(SessionData)
if !ok {
return s.defaultSessionData()
}
return sessionData
}
func (s *SessionWrapper) DestroyData(ctx context.Context) {
_ = s.SessionManager.Destroy(ctx)
}
func (s *SessionWrapper) defaultSessionData() SessionData {
return SessionData{
LoggedIn: false,
IsAdmin: false,
UserIdentifier: "",
Firstname: "",
Lastname: "",
Email: "",
OauthState: "",
OauthNonce: "",
OauthProvider: "",
OauthReturnTo: "",
}
}

View File

@ -4,17 +4,19 @@ import (
"errors"
"net/http"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core"
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
type Handler interface {
// GetName returns the name of the handler.
GetName() string
RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler)
// RegisterRoutes registers the routes for the handler. The session manager is passed to the handler.
RegisterRoutes(g *routegroup.Bundle)
}
// To compile the API documentation use the
@ -38,18 +40,14 @@ type Handler interface {
// @BasePath /api/v1
// @query.collection.format multi
func NewRestApi(userSource UserSource, handlers ...Handler) core.ApiEndpointSetupFunc {
authenticator := &authenticationHandler{
userSource: userSource,
}
func NewRestApi(handlers ...Handler) core.ApiEndpointSetupFunc {
return func() (core.ApiVersion, core.GroupSetupFn) {
return "v1", func(group *gin.RouterGroup) {
group.Use(cors.Default())
return "v1", func(group *routegroup.Bundle) {
group.Use(cors.New().Handler)
// Handler functions
for _, h := range handlers {
h.RegisterRoutes(group, authenticator)
h.RegisterRoutes(group)
}
}
}
@ -80,3 +78,12 @@ func ParseServiceError(err error) (int, models.Error) {
Message: err.Error(),
}
}
type Authenticator interface {
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler
}
type Validator interface {
Struct(s interface{}) error
}

View File

@ -4,8 +4,10 @@ import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
@ -19,12 +21,20 @@ type InterfaceEndpointInterfaceService interface {
}
type InterfaceEndpoint struct {
interfaces InterfaceEndpointInterfaceService
interfaces InterfaceEndpointInterfaceService
authenticator Authenticator
validator Validator
}
func NewInterfaceEndpoint(interfaceService InterfaceEndpointInterfaceService) *InterfaceEndpoint {
func NewInterfaceEndpoint(
authenticator Authenticator,
validator Validator,
interfaceService InterfaceEndpointInterfaceService,
) *InterfaceEndpoint {
return &InterfaceEndpoint{
interfaces: interfaceService,
authenticator: authenticator,
validator: validator,
interfaces: interfaceService,
}
}
@ -32,15 +42,16 @@ func (e InterfaceEndpoint) GetName() string {
return "InterfaceEndpoint"
}
func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/interface", authenticator.LoggedIn())
func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/interface")
apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin))
apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleByIdGet())
apiGroup.HandleFunc("GET /all", e.handleAllGet())
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
apiGroup.HandleFunc("POST /new", e.handleCreatePost())
apiGroup.HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
apiGroup.HandleFunc("DELETE /by-id/{id}", e.handleDelete())
}
// handleAllGet returns a gorm Handler function.
@ -54,17 +65,16 @@ func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut
// @Failure 500 {object} models.Error
// @Router /interface/all [get]
// @Security BasicAuth
func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(ctx)
func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(r.Context())
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface))
respond.JSON(w, http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface))
}
}
@ -82,23 +92,23 @@ func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /interface/by-id/{id} [get]
// @Security BasicAuth
func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e InterfaceEndpoint) handleByIdGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
iface, interfacePeers, err := e.interfaces.GetById(ctx, domain.InterfaceIdentifier(id))
iface, interfacePeers, err := e.interfaces.GetById(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewInterface(iface, interfacePeers))
respond.JSON(w, http.StatusOK, models.NewInterface(iface, interfacePeers))
}
}
@ -117,24 +127,26 @@ func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /interface/new [post]
// @Security BasicAuth
func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var iface models.Interface
err := c.BindJSON(&iface)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &iface); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(iface); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newInterface, err := e.interfaces.Create(ctx, models.NewDomainInterface(&iface))
newInterface, err := e.interfaces.Create(r.Context(), models.NewDomainInterface(&iface))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewInterface(newInterface, nil))
respond.JSON(w, http.StatusOK, models.NewInterface(newInterface, nil))
}
}
@ -154,34 +166,43 @@ func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /interface/by-id/{id} [put]
// @Security BasicAuth
func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
var iface models.Interface
err := c.BindJSON(&iface)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &iface); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(iface); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if id != iface.Identifier {
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
return
}
updatedInterface, updatedInterfacePeers, err := e.interfaces.Update(
ctx,
r.Context(),
domain.InterfaceIdentifier(id),
models.NewDomainInterface(&iface),
)
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers))
respond.JSON(w, http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers))
}
}
@ -200,22 +221,22 @@ func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /interface/by-id/{id} [delete]
// @Security BasicAuth
func (e InterfaceEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e InterfaceEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
err := e.interfaces.Delete(ctx, domain.InterfaceIdentifier(id))
err := e.interfaces.Delete(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}

View File

@ -4,8 +4,10 @@ import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
@ -17,12 +19,20 @@ type MetricsEndpointStatisticsService interface {
}
type MetricsEndpoint struct {
metrics MetricsEndpointStatisticsService
metrics MetricsEndpointStatisticsService
authenticator Authenticator
validator Validator
}
func NewMetricsEndpoint(metrics MetricsEndpointStatisticsService) *MetricsEndpoint {
func NewMetricsEndpoint(
authenticator Authenticator,
validator Validator,
metrics MetricsEndpointStatisticsService,
) *MetricsEndpoint {
return &MetricsEndpoint{
metrics: metrics,
authenticator: authenticator,
validator: validator,
metrics: metrics,
}
}
@ -30,12 +40,14 @@ func (e MetricsEndpoint) GetName() string {
return "MetricsEndpoint"
}
func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/metrics", authenticator.LoggedIn())
func (e MetricsEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/metrics")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleMetricsForInterfaceGet())
apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleMetricsForUserGet())
apiGroup.GET("/by-peer/:id", authenticator.LoggedIn(), e.handleMetricsForPeerGet())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}",
e.handleMetricsForInterfaceGet())
apiGroup.HandleFunc("GET /by-user/{id}", e.handleMetricsForUserGet())
apiGroup.HandleFunc("GET /by-peer/{id}", e.handleMetricsForPeerGet())
}
// handleMetricsForInterfaceGet returns a gorm Handler function.
@ -52,23 +64,23 @@ func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authe
// @Failure 500 {object} models.Error
// @Router /metrics/by-interface/{id} [get]
// @Security BasicAuth
func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e MetricsEndpoint) handleMetricsForInterfaceGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
interfaceMetrics, err := e.metrics.GetForInterface(ctx, domain.InterfaceIdentifier(id))
interfaceMetrics, err := e.metrics.GetForInterface(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics))
respond.JSON(w, http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics))
}
}
@ -86,23 +98,23 @@ func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /metrics/by-user/{id} [get]
// @Security BasicAuth
func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e MetricsEndpoint) handleMetricsForUserGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
user, userMetrics, err := e.metrics.GetForUser(ctx, domain.UserIdentifier(id))
user, userMetrics, err := e.metrics.GetForUser(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUserMetrics(user, userMetrics))
respond.JSON(w, http.StatusOK, models.NewUserMetrics(user, userMetrics))
}
}
@ -120,22 +132,22 @@ func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /metrics/by-peer/{id} [get]
// @Security BasicAuth
func (e MetricsEndpoint) handleMetricsForPeerGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e MetricsEndpoint) handleMetricsForPeerGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
peerMetrics, err := e.metrics.GetForPeer(ctx, domain.PeerIdentifier(id))
peerMetrics, err := e.metrics.GetForPeer(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeerMetrics(peerMetrics))
respond.JSON(w, http.StatusOK, models.NewPeerMetrics(peerMetrics))
}
}

View File

@ -4,8 +4,10 @@ import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
@ -20,12 +22,19 @@ type PeerService interface {
}
type PeerEndpoint struct {
peers PeerService
peers PeerService
authenticator Authenticator
validator Validator
}
func NewPeerEndpoint(peerService PeerService) *PeerEndpoint {
func NewPeerEndpoint(
authenticator Authenticator,
validator Validator, peerService PeerService,
) *PeerEndpoint {
return &PeerEndpoint{
peers: peerService,
authenticator: authenticator,
validator: validator,
peers: peerService,
}
}
@ -33,16 +42,18 @@ func (e PeerEndpoint) GetName() string {
return "PeerEndpoint"
}
func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/peer", authenticator.LoggedIn())
func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/peer")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleAllForInterfaceGet())
apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleAllForUserGet())
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}",
e.handleAllForInterfaceGet())
apiGroup.HandleFunc("GET /by-user/{id}", e.handleAllForUserGet())
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete())
}
// handleAllForInterfaceGet returns a gorm Handler function.
@ -57,23 +68,23 @@ func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
// @Failure 500 {object} models.Error
// @Router /peer/by-interface/{id} [get]
// @Security BasicAuth
func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e PeerEndpoint) handleAllForInterfaceGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
return
}
interfacePeers, err := e.peers.GetForInterface(ctx, domain.InterfaceIdentifier(id))
interfacePeers, err := e.peers.GetForInterface(r.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeers(interfacePeers))
respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers))
}
}
@ -90,23 +101,23 @@ func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /peer/by-user/{id} [get]
// @Security BasicAuth
func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e PeerEndpoint) handleAllForUserGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
interfacePeers, err := e.peers.GetForUser(ctx, domain.UserIdentifier(id))
interfacePeers, err := e.peers.GetForUser(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeers(interfacePeers))
respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers))
}
}
@ -125,23 +136,23 @@ func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /peer/by-id/{id} [get]
// @Security BasicAuth
func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e PeerEndpoint) handleByIdGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
peer, err := e.peers.GetById(ctx, domain.PeerIdentifier(id))
peer, err := e.peers.GetById(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeer(peer))
respond.JSON(w, http.StatusOK, models.NewPeer(peer))
}
}
@ -161,24 +172,26 @@ func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /peer/new [post]
// @Security BasicAuth
func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e PeerEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var peer models.Peer
err := c.BindJSON(&peer)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &peer); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(peer); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newPeer, err := e.peers.Create(ctx, models.NewDomainPeer(&peer))
newPeer, err := e.peers.Create(r.Context(), models.NewDomainPeer(&peer))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeer(newPeer))
respond.JSON(w, http.StatusOK, models.NewPeer(newPeer))
}
}
@ -199,30 +212,33 @@ func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /peer/by-id/{id} [put]
// @Security BasicAuth
func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
var peer models.Peer
err := c.BindJSON(&peer)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &peer); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(peer); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
updatedPeer, err := e.peers.Update(ctx, domain.PeerIdentifier(id), models.NewDomainPeer(&peer))
updatedPeer, err := e.peers.Update(r.Context(), domain.PeerIdentifier(id), models.NewDomainPeer(&peer))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeer(updatedPeer))
respond.JSON(w, http.StatusOK, models.NewPeer(updatedPeer))
}
}
@ -241,22 +257,22 @@ func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /peer/by-id/{id} [delete]
// @Security BasicAuth
func (e PeerEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e PeerEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
err := e.peers.Delete(ctx, domain.PeerIdentifier(id))
err := e.peers.Delete(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}

View File

@ -5,8 +5,10 @@ import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
@ -23,12 +25,20 @@ type ProvisioningEndpointProvisioningService interface {
}
type ProvisioningEndpoint struct {
provisioning ProvisioningEndpointProvisioningService
provisioning ProvisioningEndpointProvisioningService
authenticator Authenticator
validator Validator
}
func NewProvisioningEndpoint(provisioning ProvisioningEndpointProvisioningService) *ProvisioningEndpoint {
func NewProvisioningEndpoint(
authenticator Authenticator,
validator Validator,
provisioning ProvisioningEndpointProvisioningService,
) *ProvisioningEndpoint {
return &ProvisioningEndpoint{
provisioning: provisioning,
authenticator: authenticator,
validator: validator,
provisioning: provisioning,
}
}
@ -36,14 +46,15 @@ func (e ProvisioningEndpoint) GetName() string {
return "ProvisioningEndpoint"
}
func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/provisioning", authenticator.LoggedIn())
func (e ProvisioningEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/provisioning")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/data/user-info", authenticator.LoggedIn(), e.handleUserInfoGet())
apiGroup.GET("/data/peer-config", authenticator.LoggedIn(), e.handlePeerConfigGet())
apiGroup.GET("/data/peer-qr", authenticator.LoggedIn(), e.handlePeerQrGet())
apiGroup.HandleFunc("GET /data/user-info", e.handleUserInfoGet())
apiGroup.HandleFunc("GET /data/peer-config", e.handlePeerConfigGet())
apiGroup.HandleFunc("GET /data/peer-qr", e.handlePeerQrGet())
apiGroup.POST("/new-peer", authenticator.LoggedIn(), e.handleNewPeerPost())
apiGroup.HandleFunc("POST /new-peer", e.handleNewPeerPost())
}
// handleUserInfoGet returns a gorm Handler function.
@ -63,24 +74,23 @@ func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *
// @Failure 500 {object} models.Error
// @Router /provisioning/data/user-info [get]
// @Security BasicAuth
func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := strings.TrimSpace(c.Query("UserId"))
email := strings.TrimSpace(c.Query("Email"))
func (e ProvisioningEndpoint) handleUserInfoGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := strings.TrimSpace(request.Query(r, "UserId"))
email := strings.TrimSpace(request.Query(r, "Email"))
if id == "" && email == "" {
id = string(domain.GetUserInfo(ctx).Id)
id = string(domain.GetUserInfo(r.Context()).Id)
}
user, peers, err := e.provisioning.GetUserAndPeers(ctx, domain.UserIdentifier(id), email)
user, peers, err := e.provisioning.GetUserAndPeers(r.Context(), domain.UserIdentifier(id), email)
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUserInformation(user, peers))
respond.JSON(w, http.StatusOK, models.NewUserInformation(user, peers))
}
}
@ -101,23 +111,23 @@ func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /provisioning/data/peer-config [get]
// @Security BasicAuth
func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := strings.TrimSpace(c.Query("PeerId"))
func (e ProvisioningEndpoint) handlePeerConfigGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := strings.TrimSpace(request.Query(r, "PeerId"))
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
peerConfig, err := e.provisioning.GetPeerConfig(ctx, domain.PeerIdentifier(id))
peerConfig, err := e.provisioning.GetPeerConfig(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.Data(http.StatusOK, "text/plain", peerConfig)
respond.Data(w, http.StatusOK, "text/plain", peerConfig)
}
}
@ -138,23 +148,23 @@ func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /provisioning/data/peer-qr [get]
// @Security BasicAuth
func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := strings.TrimSpace(c.Query("PeerId"))
func (e ProvisioningEndpoint) handlePeerQrGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := strings.TrimSpace(request.Query(r, "PeerId"))
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
return
}
peerConfigQrCode, err := e.provisioning.GetPeerQrPng(ctx, domain.PeerIdentifier(id))
peerConfigQrCode, err := e.provisioning.GetPeerQrPng(r.Context(), domain.PeerIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.Data(http.StatusOK, "image/png", peerConfigQrCode)
respond.Data(w, http.StatusOK, "image/png", peerConfigQrCode)
}
}
@ -174,23 +184,25 @@ func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /provisioning/new-peer [post]
// @Security BasicAuth
func (e ProvisioningEndpoint) handleNewPeerPost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e ProvisioningEndpoint) handleNewPeerPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req models.ProvisioningRequest
err := c.BindJSON(&req)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &req); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(req); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
peer, err := e.provisioning.NewPeer(ctx, req)
peer, err := e.provisioning.NewPeer(r.Context(), req)
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewPeer(peer))
respond.JSON(w, http.StatusOK, models.NewPeer(peer))
}
}

View File

@ -4,8 +4,10 @@ import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-pkgz/routegroup"
"github.com/h44z/wg-portal/internal/app/api/core/request"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v1/models"
"github.com/h44z/wg-portal/internal/domain"
)
@ -19,12 +21,20 @@ type UserService interface {
}
type UserEndpoint struct {
users UserService
users UserService
authenticator Authenticator
validator Validator
}
func NewUserEndpoint(userService UserService) *UserEndpoint {
func NewUserEndpoint(
authenticator Authenticator,
validator Validator,
userService UserService,
) *UserEndpoint {
return &UserEndpoint{
users: userService,
authenticator: authenticator,
validator: validator,
users: userService,
}
}
@ -32,14 +42,15 @@ func (e UserEndpoint) GetName() string {
return "UserEndpoint"
}
func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/user", authenticator.LoggedIn())
func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/user")
apiGroup.Use(e.authenticator.LoggedIn())
apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet())
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet())
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete())
}
// handleAllGet returns a gorm Handler function.
@ -53,17 +64,16 @@ func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
// @Failure 500 {object} models.Error
// @Router /user/all [get]
// @Security BasicAuth
func (e UserEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
users, err := e.users.GetAll(ctx)
func (e UserEndpoint) handleAllGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
users, err := e.users.GetAll(r.Context())
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUsers(users))
respond.JSON(w, http.StatusOK, models.NewUsers(users))
}
}
@ -82,23 +92,23 @@ func (e UserEndpoint) handleAllGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /user/by-id/{id} [get]
// @Security BasicAuth
func (e UserEndpoint) handleByIdGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e UserEndpoint) handleByIdGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
user, err := e.users.GetById(ctx, domain.UserIdentifier(id))
user, err := e.users.GetById(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUser(user, true))
respond.JSON(w, http.StatusOK, models.NewUser(user, true))
}
}
@ -118,24 +128,26 @@ func (e UserEndpoint) handleByIdGet() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /user/new [post]
// @Security BasicAuth
func (e UserEndpoint) handleCreatePost() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
func (e UserEndpoint) handleCreatePost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var user models.User
err := c.BindJSON(&user)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &user); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(user); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
newUser, err := e.users.Create(ctx, models.NewDomainUser(&user))
newUser, err := e.users.Create(r.Context(), models.NewDomainUser(&user))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUser(newUser, true))
respond.JSON(w, http.StatusOK, models.NewUser(newUser, true))
}
}
@ -156,30 +168,33 @@ func (e UserEndpoint) handleCreatePost() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /user/by-id/{id} [put]
// @Security BasicAuth
func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e UserEndpoint) handleUpdatePut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
var user models.User
err := c.BindJSON(&user)
if err != nil {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
if err := request.BodyJson(r, &user); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if err := e.validator.Struct(user); err != nil {
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
updateUser, err := e.users.Update(ctx, domain.UserIdentifier(id), models.NewDomainUser(&user))
updateUser, err := e.users.Update(r.Context(), domain.UserIdentifier(id), models.NewDomainUser(&user))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.JSON(http.StatusOK, models.NewUser(updateUser, true))
respond.JSON(w, http.StatusOK, models.NewUser(updateUser, true))
}
}
@ -198,22 +213,22 @@ func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc {
// @Failure 500 {object} models.Error
// @Router /user/by-id/{id} [delete]
// @Security BasicAuth
func (e UserEndpoint) handleDelete() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := c.Param("id")
func (e UserEndpoint) handleDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := request.Path(r, "id")
if id == "" {
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
respond.JSON(w, http.StatusBadRequest,
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
return
}
err := e.users.Delete(ctx, domain.UserIdentifier(id))
err := e.users.Delete(r.Context(), domain.UserIdentifier(id))
if err != nil {
c.JSON(ParseServiceError(err))
status, model := ParseServiceError(err)
respond.JSON(w, status, model)
return
}
c.Status(http.StatusNoContent)
respond.Status(w, http.StatusNoContent)
}
}

View File

@ -1,93 +0,0 @@
package handlers
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type Scope string
const (
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
)
type UserSource interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type authenticationHandler struct {
userSource UserSource
}
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
return func(c *gin.Context) {
username, password, ok := c.Request.BasicAuth()
if !ok || username == "" || password == "" {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"})
return
}
// check if user exists in DB
ctx := domain.SetUserInfo(c.Request.Context(), domain.SystemAdminContextUserInfo())
user, err := h.userSource.GetUser(ctx, domain.UserIdentifier(username))
if err != nil {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
return
}
// validate API token
if err := user.CheckApiToken(password); err != nil {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
return
}
if !UserHasScopes(user, scopes...) {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{
Id: user.Identifier,
IsAdmin: user.IsAdmin,
})
// Continue down the chain to Handler etc
c.Next()
}
}
func UserHasScopes(user *domain.User, scopes ...Scope) bool {
// No scopes give, so the check should succeed
if len(scopes) == 0 {
return true
}
// check if user has admin scope
if user.IsAdmin {
return true
}
// Check if admin scope is required
for _, scope := range scopes {
if scope == ScopeAdmin {
return false
}
}
return true
}

View File

@ -0,0 +1,101 @@
package handlers
import (
"context"
"net/http"
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/domain"
)
type Scope string
const (
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
)
type UserAuthenticator interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type AuthenticationHandler struct {
authenticator UserAuthenticator
}
func NewAuthenticationHandler(authenticator UserAuthenticator) AuthenticationHandler {
return AuthenticationHandler{
authenticator: authenticator,
}
}
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username == "" || password == "" {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"})
return
}
// check if user exists in DB
ctx := domain.SetUserInfo(r.Context(), domain.SystemAdminContextUserInfo())
user, err := h.authenticator.GetUser(ctx, domain.UserIdentifier(username))
if err != nil {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
return
}
// validate API token
if err := user.CheckApiToken(password); err != nil {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusUnauthorized,
model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
return
}
if !UserHasScopes(user, scopes...) {
// Abort the request with the appropriate error code
respond.JSON(w, http.StatusForbidden,
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
ctx = context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{
Id: user.Identifier,
IsAdmin: user.IsAdmin,
})
r = r.WithContext(ctx)
// Continue down the chain to Handler etc
next.ServeHTTP(w, r)
})
}
}
func UserHasScopes(user *domain.User, scopes ...Scope) bool {
// No scopes give, so the check should succeed
if len(scopes) == 0 {
return true
}
// check if user has admin scope
if user.IsAdmin {
return true
}
// Check if admin scope is required
for _, scope := range scopes {
if scope == ScopeAdmin {
return false
}
}
return true
}

View File

@ -3,6 +3,8 @@ package config
type WebConfig struct {
// RequestLogging enables logging of all HTTP requests.
RequestLogging bool `yaml:"request_logging"`
// ExposeHostInfo sets whether the host information should be exposed in a response header.
ExposeHostInfo bool `yaml:"expose_host_info"`
// ExternalUrl is the URL where a client can access WireGuard Portal.
// This is used for the callback URL of the OAuth providers.
ExternalUrl string `yaml:"external_url"`

View File

@ -4,8 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"github.com/gin-gonic/gin"
)
const CtxUserInfo = "userInfo"
@ -47,21 +45,6 @@ func SystemAdminContextUserInfo() *ContextUserInfo {
}
}
// SetUserInfoFromGin sets the user info from the gin context to the request context.
func SetUserInfoFromGin(c *gin.Context) context.Context {
ginUserInfo, exists := c.Get(CtxUserInfo)
info := DefaultContextUserInfo()
if exists {
if ginInfo, ok := ginUserInfo.(*ContextUserInfo); ok {
info = ginInfo
}
}
ctx := SetUserInfo(c.Request.Context(), info)
return ctx
}
// SetUserInfo sets the user info in the context.
func SetUserInfo(ctx context.Context, info *ContextUserInfo) context.Context {
ctx = context.WithValue(ctx, CtxUserInfo, info)