mirror of
https://github.com/h44z/wg-portal.git
synced 2025-04-19 00:45:17 +00:00
chore: replace gin with standard lib net/http
This commit is contained in:
parent
7473132932
commit
0206952182
@ -7,6 +7,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
@ -101,21 +102,48 @@ func main() {
|
||||
err = backend.Startup(ctx)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
apiFrontend := handlersV0.NewRestApi(cfg, backend)
|
||||
validatorManager := validator.New()
|
||||
|
||||
// region API v0 (SPA frontend)
|
||||
|
||||
apiV0Session := handlersV0.NewSessionWrapper(cfg)
|
||||
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
|
||||
|
||||
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(backend, apiV0Auth, apiV0Session, validatorManager)
|
||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(backend, apiV0Auth, validatorManager)
|
||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(backend, apiV0Auth, validatorManager)
|
||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(backend, apiV0Auth, validatorManager)
|
||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||
|
||||
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
||||
apiV0EndpointAuth,
|
||||
apiV0EndpointUsers,
|
||||
apiV0EndpointInterfaces,
|
||||
apiV0EndpointPeers,
|
||||
apiV0EndpointConfig,
|
||||
apiV0EndpointTest,
|
||||
)
|
||||
|
||||
// endregion API v0 (SPA frontend)
|
||||
|
||||
// region API v1 (User REST API)
|
||||
|
||||
apiV1Auth := handlersV1.NewAuthenticationHandler(userManager)
|
||||
apiV1BackendUsers := backendV1.NewUserService(cfg, userManager)
|
||||
apiV1BackendPeers := backendV1.NewPeerService(cfg, wireGuardManager, userManager)
|
||||
apiV1BackendInterfaces := backendV1.NewInterfaceService(cfg, wireGuardManager)
|
||||
apiV1BackendProvisioning := backendV1.NewProvisioningService(cfg, userManager, wireGuardManager, cfgFileManager)
|
||||
apiV1BackendMetrics := backendV1.NewMetricsService(cfg, database, userManager, wireGuardManager)
|
||||
apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1BackendUsers)
|
||||
apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1BackendPeers)
|
||||
apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1BackendInterfaces)
|
||||
apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1BackendProvisioning)
|
||||
apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1BackendMetrics)
|
||||
|
||||
apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1Auth, validatorManager, apiV1BackendUsers)
|
||||
apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1Auth, validatorManager, apiV1BackendPeers)
|
||||
apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1Auth, validatorManager, apiV1BackendInterfaces)
|
||||
apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1Auth, validatorManager,
|
||||
apiV1BackendProvisioning)
|
||||
apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1Auth, validatorManager, apiV1BackendMetrics)
|
||||
|
||||
apiV1 := handlersV1.NewRestApi(
|
||||
userManager,
|
||||
apiV1EndpointUsers,
|
||||
apiV1EndpointPeers,
|
||||
apiV1EndpointInterfaces,
|
||||
@ -123,6 +151,8 @@ func main() {
|
||||
apiV1EndpointMetrics,
|
||||
)
|
||||
|
||||
// endregion API v1 (User REST API)
|
||||
|
||||
webSrv, err := core.NewServer(cfg, apiFrontend, apiV1)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
|
@ -4,6 +4,7 @@ import LoginView from '../views/LoginView.vue'
|
||||
import InterfaceView from '../views/InterfaceView.vue'
|
||||
|
||||
import {authStore} from '@/stores/auth'
|
||||
import {securityStore} from '@/stores/security'
|
||||
import {notify} from "@kyvg/vue3-notification";
|
||||
|
||||
const router = createRouter({
|
||||
@ -63,6 +64,7 @@ const router = createRouter({
|
||||
|
||||
router.beforeEach(async (to) => {
|
||||
const auth = authStore()
|
||||
const sec = securityStore()
|
||||
|
||||
// check if the request was a successful oauth login
|
||||
if ('wgLoginState' in to.query && !auth.IsAuthenticated) {
|
||||
@ -112,6 +114,10 @@ router.beforeEach(async (to) => {
|
||||
auth.SetReturnUrl(to.fullPath) // store original destination before starting the auth process
|
||||
return '/login'
|
||||
}
|
||||
|
||||
if (publicPages.includes(to.path)) {
|
||||
await sec.LoadSecurityProperties() // make sure we have a valid csrf token
|
||||
}
|
||||
})
|
||||
|
||||
export default router
|
||||
|
44
go.mod
44
go.mod
@ -4,26 +4,25 @@ go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/a8m/envsubst v1.4.3
|
||||
github.com/alexedwards/scs/v2 v2.8.0
|
||||
github.com/coreos/go-oidc/v3 v3.12.0
|
||||
github.com/gin-contrib/cors v1.7.3
|
||||
github.com/gin-contrib/sessions v1.0.2
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/go-ldap/ldap/v3 v3.4.10
|
||||
github.com/go-pkgz/routegroup v1.3.1
|
||||
github.com/go-playground/validator/v10 v10.25.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/prometheus-community/pro-bing v0.6.1
|
||||
github.com/prometheus/client_golang v1.21.0
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca
|
||||
github.com/vardius/message-bus v1.1.5
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
github.com/xhit/go-simple-mail/v2 v2.16.0
|
||||
github.com/yeqown/go-qrcode/v2 v2.2.5
|
||||
github.com/yeqown/go-qrcode/writer/compressed v1.0.1
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/oauth2 v0.27.0
|
||||
golang.org/x/sys v0.30.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@ -37,15 +36,10 @@ require (
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.12.9 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dchest/uniuri v1.2.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v1.0.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
|
||||
@ -55,16 +49,11 @@ require (
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.25.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.0 // indirect
|
||||
github.com/go-test/deep v1.1.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gorilla/context v1.1.2 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/gorilla/sessions v1.4.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.2 // indirect
|
||||
@ -73,9 +62,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mailru/easyjson v0.9.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@ -83,28 +70,21 @@ require (
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.8.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/yeqown/reedsolomon v1.0.0 // indirect
|
||||
golang.org/x/arch v0.14.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
golang.org/x/tools v0.30.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/sync v0.12.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/tools v0.31.0 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
|
||||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
modernc.org/libc v1.61.13 // indirect
|
||||
|
130
go.sum
130
go.sum
@ -29,52 +29,27 @@ github.com/a8m/envsubst v1.4.3 h1:kDF7paGK8QACWYaQo6KtyYBozY2jhQrTuNNuUxQkhJY=
|
||||
github.com/a8m/envsubst v1.4.3/go.mod h1:4jjHWQlZoaXPoLQUb7H2qT4iLkZDdmEQiOUogdUmqVU=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw=
|
||||
github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
|
||||
github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20181103040241-659414f458e1/go.mod h1:dkChI7Tbtx7H1Tj7TqGSZMOeGpMP5gLHtjroHd4agiI=
|
||||
github.com/bytedance/sonic v1.12.9 h1:Od1BvK55NnewtGaJsTDeAOSnLVO2BTSLOe0+ooKokmQ=
|
||||
github.com/bytedance/sonic v1.12.9/go.mod h1:uVvFidNmlt9+wa31S1urfwwthTWteBgG0hWuoKAXTx8=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/bytedance/sonic/loader v0.2.3 h1:yctD0Q3v2NOGfSWPLPvG2ggA2kV6TS6s4wioyEqssH0=
|
||||
github.com/bytedance/sonic/loader v0.2.3/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4=
|
||||
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
|
||||
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
|
||||
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
github.com/gin-contrib/cors v1.7.3 h1:hV+a5xp8hwJoTw7OY+a70FsL8JkVVFTXw9EcfrYUdns=
|
||||
github.com/gin-contrib/cors v1.7.3/go.mod h1:M3bcKZhxzsvI+rlRSkkxHyljJt1ESd93COUvemZ79j4=
|
||||
github.com/gin-contrib/sessions v0.0.0-20190101140330-dc5246754963/go.mod h1:4lkInX8nHSR62NSmhXM3xtPeMSyfiR58NaEz+om1lHM=
|
||||
github.com/gin-contrib/sessions v1.0.2 h1:UaIjUvTH1cMeOdj3in6dl+Xb6It8RiKRF9Z1anbUyCA=
|
||||
github.com/gin-contrib/sessions v1.0.2/go.mod h1:KxKxWqWP5LJVDCInulOl4WbLzK2KSPlLesfZ66wRvMs=
|
||||
github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
|
||||
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
|
||||
github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0=
|
||||
github.com/gin-gonic/gin v1.3.0/go.mod h1:7cKuhb5qV2ggCFctp2fJQ+ErvciLZrIeoOSOm6mUr7Y=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
||||
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.7 h1:DTX+lbVTWaTw1hQ+PbZPlnDZPEIs0SS/GCZAl535dDk=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.7/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
@ -89,6 +64,8 @@ github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9Z
|
||||
github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk=
|
||||
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
|
||||
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
||||
github.com/go-pkgz/routegroup v1.3.1 h1:XAVWskX8Iup6HoQD9zv+gJx4DOJC2DSkKBHCMeeW8/s=
|
||||
github.com/go-pkgz/routegroup v1.3.1/go.mod h1:kDDPDRLRiRY1vnENrZJw1jQAzQX7fvsbsHGRQFNQfKc=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@ -102,8 +79,6 @@ github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1
|
||||
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
|
||||
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
@ -112,32 +87,18 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
|
||||
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
||||
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
@ -169,21 +130,10 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kidstuff/mongostore v0.0.0-20181113001930-e650cd85ee4b/go.mod h1:g2nVr8KZVXJSS97Jo8pJ0jgq29P6H7dG0oplUA86MQw=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
@ -192,7 +142,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
@ -201,26 +150,17 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc=
|
||||
github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
||||
github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw=
|
||||
github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
|
||||
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
@ -228,17 +168,14 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus-community/pro-bing v0.6.1 h1:EQukUOma9YFZRPe4DGSscxUf9LH07rpqwisNWjSZrgU=
|
||||
github.com/prometheus-community/pro-bing v0.6.1/go.mod h1:jNCOI3D7pmTCeaoF41cNS6uaxeFY/Gmc3ffwbuJVzAQ=
|
||||
github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
|
||||
github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quasoft/memstore v0.0.0-20180925164028-84a050167438/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
|
||||
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc=
|
||||
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@ -246,8 +183,6 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@ -262,13 +197,6 @@ github.com/swaggo/swag v1.16.4/go.mod h1:VBsHJRsDvfYvqoiMKnsdwhNV9LEMHgEDZcyVYX0
|
||||
github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns=
|
||||
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 h1:q0hKh5a5FRkhuTb5JNfgjzpzvYLHjH0QOgPZPYnRWGA=
|
||||
github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca h1:lpvAjPK+PcxnbcB8H7axIb4fMNwjX9bE4DzwPjGg8aE=
|
||||
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca/go.mod h1:XXKxNbpoLihvvT7orUZbs/iZayg1n4ip7iJakJPAwA8=
|
||||
github.com/vardius/message-bus v1.1.5 h1:YSAC2WB4HRlwc4neFPTmT88kzzoiQ+9WRRbej/E/LZc=
|
||||
github.com/vardius/message-bus v1.1.5/go.mod h1:6xladCV2lMkUAE4bzzS85qKOiB5miV7aBVRafiTJGqw=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
@ -285,8 +213,6 @@ github.com/yeqown/go-qrcode/writer/compressed v1.0.1/go.mod h1:BJScsGUIKM+eg0CCL
|
||||
github.com/yeqown/reedsolomon v1.0.0 h1:x1h/Ej/uJnNu8jaX7GLHBWmZKCAWjEJTetkqaabr4B0=
|
||||
github.com/yeqown/reedsolomon v1.0.0/go.mod h1:P76zpcn2TCuL0ul1Fso373qHRc69LKwAw/Iy6g1WiiM=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4=
|
||||
golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
@ -300,18 +226,17 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
|
||||
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@ -329,11 +254,10 @@ golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
|
||||
golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
|
||||
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -341,9 +265,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20181228144115-9a3f9b0469bb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@ -364,8 +287,8 @@ golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@ -393,16 +316,16 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU=
|
||||
golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
@ -411,12 +334,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y=
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
@ -458,6 +377,5 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
|
||||
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
|
||||
|
214
internal/app/api/core/middleware/cors/middleware.go
Normal file
214
internal/app/api/core/middleware/cors/middleware.go
Normal file
@ -0,0 +1,214 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Middleware is a type that creates a new CORS middleware. The CORS middleware
|
||||
// adds Cross-Origin Resource Sharing headers to the response. This middleware should
|
||||
// be used to allow cross-origin requests to your server.
|
||||
type Middleware struct {
|
||||
o options
|
||||
|
||||
varyHeaders string // precomputed Vary header
|
||||
allOrigins bool // all origins are allowed
|
||||
}
|
||||
|
||||
// New returns a new CORS middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
// set vary headers
|
||||
if m.o.allowPrivateNetworks {
|
||||
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"
|
||||
} else {
|
||||
m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers"
|
||||
}
|
||||
|
||||
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
|
||||
m.allOrigins = true
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the CORS middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle preflight requests and stop the chain as some other
|
||||
// middleware may not handle OPTIONS requests correctly.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
m.handlePreflight(w, r)
|
||||
w.WriteHeader(http.StatusNoContent) // always return 204 No Content
|
||||
return
|
||||
}
|
||||
|
||||
// handle normal CORS requests
|
||||
m.handleNormal(w, r)
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
// handlePreflight handles preflight requests. If the request was successful, this function will
|
||||
// write the CORS headers and return. If the request was not successful, this function will
|
||||
// not add any CORS headers and return - thus the CORS request is considered invalid.
|
||||
func (m *Middleware) handlePreflight(w http.ResponseWriter, r *http.Request) {
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
w.Header().Add("Vary", m.varyHeaders)
|
||||
|
||||
// check origin
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return // not a valid CORS request
|
||||
}
|
||||
|
||||
if !m.originAllowed(origin) {
|
||||
return
|
||||
}
|
||||
|
||||
// check method
|
||||
reqMethod := r.Header.Get("Access-Control-Request-Method")
|
||||
if !m.methodAllowed(reqMethod) {
|
||||
return
|
||||
}
|
||||
|
||||
// check headers
|
||||
reqHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||||
if !m.headersAllowed(reqHeaders) {
|
||||
return
|
||||
}
|
||||
|
||||
// set CORS headers for the successful preflight request
|
||||
if m.allOrigins {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", reqMethod)
|
||||
if reqHeaders != "" {
|
||||
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
|
||||
// from Access-Control-Request-Headers can be enough
|
||||
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
|
||||
}
|
||||
if m.o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if m.o.allowPrivateNetworks && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
|
||||
w.Header().Set("Access-Control-Allow-Private-Network", "true")
|
||||
}
|
||||
if m.o.maxAge > 0 {
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.o.maxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// handleNormal handles normal CORS requests. If the request was successful, this function will
|
||||
// write the CORS headers to the response. If the request was not successful, this function will
|
||||
// not add any CORS headers to the response. In this case, the CORS request is considered invalid.
|
||||
func (m *Middleware) handleNormal(w http.ResponseWriter, r *http.Request) {
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// check origin
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return // not a valid CORS request
|
||||
}
|
||||
|
||||
if !m.originAllowed(origin) {
|
||||
return
|
||||
}
|
||||
|
||||
// check method
|
||||
if !m.methodAllowed(r.Method) {
|
||||
return
|
||||
}
|
||||
|
||||
// set CORS headers for the successful CORS request
|
||||
if m.allOrigins {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin
|
||||
}
|
||||
if len(m.o.exposedHeaders) > 0 {
|
||||
w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.o.exposedHeaders, ", "))
|
||||
}
|
||||
if m.o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) originAllowed(origin string) bool {
|
||||
if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
// check simple origins
|
||||
if slices.Contains(m.o.allowedOrigins, origin) {
|
||||
return true
|
||||
}
|
||||
|
||||
// check wildcard origins
|
||||
for _, allowedOrigin := range m.o.allowedOriginPatterns {
|
||||
if allowedOrigin.match(origin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Middleware) methodAllowed(method string) bool {
|
||||
if method == http.MethodOptions {
|
||||
return true // preflight request is always allowed
|
||||
}
|
||||
|
||||
if len(m.o.allowedMethods) == 1 && m.o.allowedMethods[0] == "*" {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
if slices.Contains(m.o.allowedMethods, method) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Middleware) headersAllowed(headers string) bool {
|
||||
if headers == "" {
|
||||
return true // no headers are requested
|
||||
}
|
||||
|
||||
if len(m.o.allowedHeaders) == 0 {
|
||||
return false // no headers are allowed
|
||||
}
|
||||
|
||||
if _, ok := m.o.allowedHeaders["*"]; ok {
|
||||
return true // everything is allowed
|
||||
}
|
||||
|
||||
// split headers by comma (according to definition, the headers are sorted and in lowercase)
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers
|
||||
for header := range strings.SplitSeq(headers, ",") {
|
||||
if _, ok := m.o.allowedHeaders[strings.TrimSpace(header)]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
101
internal/app/api/core/middleware/cors/middleware_test.go
Normal file
101
internal/app/api/core/middleware/cors/middleware_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMiddleware_New(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("*"))
|
||||
|
||||
if len(m.varyHeaders) == 0 {
|
||||
t.Errorf("expected vary headers to be populated, got %v", m.varyHeaders)
|
||||
}
|
||||
if !m.allOrigins {
|
||||
t.Errorf("expected allOrigins to be true, got %v", m.allOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_normal(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Result().StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status code 200, got %d", w.Result().StatusCode)
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
|
||||
w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_preflight(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "http://example.com", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Result().StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected status code 204, got %d", w.Result().StatusCode)
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s",
|
||||
w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_originAllowed(t *testing.T) {
|
||||
m := New(WithAllowedOrigins("http://example.com"))
|
||||
|
||||
if !m.originAllowed("http://example.com") {
|
||||
t.Errorf("expected origin 'http://example.com' to be allowed")
|
||||
}
|
||||
|
||||
if m.originAllowed("http://notallowed.com") {
|
||||
t.Errorf("expected origin 'http://notallowed.com' to be not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_methodAllowed(t *testing.T) {
|
||||
m := New(WithAllowedMethods(http.MethodGet, http.MethodPost))
|
||||
|
||||
if !m.methodAllowed(http.MethodGet) {
|
||||
t.Errorf("expected method 'GET' to be allowed")
|
||||
}
|
||||
|
||||
if m.methodAllowed(http.MethodDelete) {
|
||||
t.Errorf("expected method 'DELETE' to be not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_headersAllowed(t *testing.T) {
|
||||
m := New(WithAllowedHeaders("Content-Type", "Authorization"))
|
||||
|
||||
if !m.headersAllowed("content-type, authorization") {
|
||||
t.Errorf("expected headers 'Content-Type, Authorization' to be allowed")
|
||||
}
|
||||
|
||||
if m.headersAllowed("x-custom-header") {
|
||||
t.Errorf("expected header 'X-Custom-Header' to be not allowed")
|
||||
}
|
||||
}
|
133
internal/app/api/core/middleware/cors/options.go
Normal file
133
internal/app/api/core/middleware/cors/options.go
Normal file
@ -0,0 +1,133 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type void struct{}
|
||||
|
||||
// options is a struct that contains options for the CORS middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
allowedOrigins []string // origins without wildcards
|
||||
allowedOriginPatterns []wildcard // origins with wildcards
|
||||
allowedMethods []string
|
||||
allowedHeaders map[string]void
|
||||
exposedHeaders []string // these are in addition to the CORS-safelisted response headers
|
||||
allowCredentials bool
|
||||
allowPrivateNetworks bool
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the CORS middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithAllowedOrigins sets the allowed origins for the CORS middleware.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// An origin may contain a wildcard (*) to replace 0 or more characters
|
||||
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
|
||||
// Only one wildcard can be used per origin.
|
||||
// By default, all origins are allowed (*).
|
||||
func WithAllowedOrigins(origins ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedOrigins = nil
|
||||
o.allowedOriginPatterns = nil
|
||||
|
||||
for _, origin := range origins {
|
||||
if len(origin) > 1 && strings.Contains(origin, "*") {
|
||||
o.allowedOriginPatterns = append(
|
||||
o.allowedOriginPatterns,
|
||||
newWildcard(origin),
|
||||
)
|
||||
} else {
|
||||
o.allowedOrigins = append(o.allowedOrigins, origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedMethods sets the allowed methods for the CORS middleware.
|
||||
// By default, all methods are allowed (*).
|
||||
func WithAllowedMethods(methods ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedMethods = methods
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedHeaders sets the allowed headers for the CORS middleware.
|
||||
// By default, all headers are allowed (*).
|
||||
func WithAllowedHeaders(headers ...string) Option {
|
||||
return func(o *options) {
|
||||
o.allowedHeaders = make(map[string]void)
|
||||
|
||||
for _, header := range headers {
|
||||
// allowed headers are always checked in lowercase
|
||||
o.allowedHeaders[strings.ToLower(header)] = void{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposedHeaders sets the exposed headers for the CORS middleware.
|
||||
// By default, no headers are exposed.
|
||||
func WithExposedHeaders(headers ...string) Option {
|
||||
return func(o *options) {
|
||||
o.exposedHeaders = nil
|
||||
|
||||
for _, header := range headers {
|
||||
o.exposedHeaders = append(o.exposedHeaders, http.CanonicalHeaderKey(header))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowCredentials sets the allow credentials option for the CORS middleware.
|
||||
// This setting indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
// By default, credentials are not allowed.
|
||||
func WithAllowCredentials(allow bool) Option {
|
||||
return func(o *options) {
|
||||
o.allowCredentials = allow
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowPrivateNetworks sets the allow private networks option for the CORS middleware.
|
||||
// This setting indicates whether to accept cross-origin requests over a private network.
|
||||
func WithAllowPrivateNetworks(allow bool) Option {
|
||||
return func(o *options) {
|
||||
o.allowPrivateNetworks = allow
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxAge sets the max age (in seconds) for the CORS middleware.
|
||||
// The maximum age indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached. A value of 0 means that no Access-Control-Max-Age header is sent back,
|
||||
// resulting in browsers using their default value (5s by spec).
|
||||
// If you need to force a 0 max-age, set it to a negative value (ie: -1).
|
||||
// By default, the max age is 7200 seconds.
|
||||
func WithMaxAge(age int) Option {
|
||||
return func(o *options) {
|
||||
o.maxAge = age
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
allowedOrigins: []string{"*"},
|
||||
allowedMethods: []string{
|
||||
http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete,
|
||||
},
|
||||
allowedHeaders: map[string]void{"*": {}},
|
||||
exposedHeaders: nil,
|
||||
allowCredentials: false,
|
||||
allowPrivateNetworks: false,
|
||||
maxAge: 0,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
96
internal/app/api/core/middleware/cors/options_test.go
Normal file
96
internal/app/api/core/middleware/cors/options_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithAllowedOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origins []string
|
||||
wantNormal []string
|
||||
wantWildcard []wildcard
|
||||
}{
|
||||
{
|
||||
name: "No origins",
|
||||
origins: []string{},
|
||||
wantNormal: nil,
|
||||
wantWildcard: nil,
|
||||
},
|
||||
{
|
||||
name: "Single origin",
|
||||
origins: []string{"http://example.com"},
|
||||
wantNormal: []string{"http://example.com"},
|
||||
wantWildcard: nil,
|
||||
},
|
||||
{
|
||||
name: "Wildcard origin",
|
||||
origins: []string{"http://*.example.com"},
|
||||
wantNormal: nil,
|
||||
wantWildcard: []wildcard{newWildcard("http://*.example.com")},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := newOptions(WithAllowedOrigins(tt.origins...))
|
||||
if !slices.Equal(o.allowedOrigins, tt.wantNormal) {
|
||||
t.Errorf("got %v, want %v", o, tt.wantNormal)
|
||||
}
|
||||
if !slices.Equal(o.allowedOriginPatterns, tt.wantWildcard) {
|
||||
t.Errorf("got %v, want %v", o, tt.wantWildcard)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowedMethods(t *testing.T) {
|
||||
methods := []string{http.MethodGet, http.MethodPost}
|
||||
o := newOptions(WithAllowedMethods(methods...))
|
||||
if !slices.Equal(o.allowedMethods, methods) {
|
||||
t.Errorf("got %v, want %v", o.allowedMethods, methods)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowedHeaders(t *testing.T) {
|
||||
headers := []string{"Content-Type", "Authorization"}
|
||||
o := newOptions(WithAllowedHeaders(headers...))
|
||||
expectedHeaders := map[string]void{"content-type": {}, "authorization": {}}
|
||||
if !maps.Equal(o.allowedHeaders, expectedHeaders) {
|
||||
t.Errorf("got %v, want %v", o.allowedHeaders, expectedHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithExposedHeaders(t *testing.T) {
|
||||
headers := []string{"X-Custom-Header"}
|
||||
o := newOptions(WithExposedHeaders(headers...))
|
||||
expectedHeaders := []string{http.CanonicalHeaderKey("X-Custom-Header")}
|
||||
if !slices.Equal(o.exposedHeaders, expectedHeaders) {
|
||||
t.Errorf("got %v, want %v", o.exposedHeaders, expectedHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowCredentials(t *testing.T) {
|
||||
o := newOptions(WithAllowCredentials(true))
|
||||
if !o.allowCredentials {
|
||||
t.Errorf("got %v, want %v", o.allowCredentials, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAllowPrivateNetworks(t *testing.T) {
|
||||
o := newOptions(WithAllowPrivateNetworks(true))
|
||||
if !o.allowPrivateNetworks {
|
||||
t.Errorf("got %v, want %v", o.allowPrivateNetworks, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMaxAge(t *testing.T) {
|
||||
maxAge := 3600
|
||||
o := newOptions(WithMaxAge(maxAge))
|
||||
if o.maxAge != maxAge {
|
||||
t.Errorf("got %v, want %v", o.maxAge, maxAge)
|
||||
}
|
||||
}
|
33
internal/app/api/core/middleware/cors/wildcard.go
Normal file
33
internal/app/api/core/middleware/cors/wildcard.go
Normal file
@ -0,0 +1,33 @@
|
||||
package cors
|
||||
|
||||
import "strings"
|
||||
|
||||
// wildcard is a type that represents a wildcard string.
|
||||
// This type allows faster matching of strings with a wildcard
|
||||
// in comparison to using regex.
|
||||
type wildcard struct {
|
||||
prefix string
|
||||
suffix string
|
||||
}
|
||||
|
||||
// match returns true if the string s has the prefix and suffix of the wildcard.
|
||||
func (w wildcard) match(s string) bool {
|
||||
return len(s) >= len(w.prefix)+len(w.suffix) &&
|
||||
strings.HasPrefix(s, w.prefix) &&
|
||||
strings.HasSuffix(s, w.suffix)
|
||||
}
|
||||
|
||||
func newWildcard(s string) wildcard {
|
||||
if i := strings.IndexByte(s, '*'); i >= 0 {
|
||||
return wildcard{
|
||||
prefix: s[:i],
|
||||
suffix: s[i+1:],
|
||||
}
|
||||
}
|
||||
|
||||
// fallback, usually this case should not happen
|
||||
return wildcard{
|
||||
prefix: s,
|
||||
suffix: "",
|
||||
}
|
||||
}
|
94
internal/app/api/core/middleware/cors/wildcard_test.go
Normal file
94
internal/app/api/core/middleware/cors/wildcard_test.go
Normal file
@ -0,0 +1,94 @@
|
||||
package cors
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWildcardMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard wildcard
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Match with prefix and suffix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "http://sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match with different prefix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "https://sub.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No match with different suffix",
|
||||
wildcard: newWildcard("http://*.example.com"),
|
||||
input: "http://sub.example.org",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Match with empty suffix",
|
||||
wildcard: newWildcard("http://*"),
|
||||
input: "http://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Match with empty prefix",
|
||||
wildcard: newWildcard("*.example.com"),
|
||||
input: "sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match with empty prefix and different suffix",
|
||||
wildcard: newWildcard("*.example.com"),
|
||||
input: "sub.example.org",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.wildcard.match(tt.input); got != tt.expected {
|
||||
t.Errorf("wildcard.match(%s) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWildcard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected wildcard
|
||||
}{
|
||||
{
|
||||
name: "Wildcard with prefix and suffix",
|
||||
input: "http://*.example.com",
|
||||
expected: wildcard{prefix: "http://", suffix: ".example.com"},
|
||||
},
|
||||
{
|
||||
name: "Wildcard with empty suffix",
|
||||
input: "http://*",
|
||||
expected: wildcard{prefix: "http://", suffix: ""},
|
||||
},
|
||||
{
|
||||
name: "Wildcard with empty prefix",
|
||||
input: "*.example.com",
|
||||
expected: wildcard{prefix: "", suffix: ".example.com"},
|
||||
},
|
||||
{
|
||||
name: "No wildcard character",
|
||||
input: "http://example.com",
|
||||
expected: wildcard{prefix: "http://example.com", suffix: ""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := newWildcard(tt.input); got != tt.expected {
|
||||
t.Errorf("newWildcard(%s) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
137
internal/app/api/core/middleware/csrf/middleware.go
Normal file
137
internal/app/api/core/middleware/csrf/middleware.go
Normal file
@ -0,0 +1,137 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// ContextValueIdentifier is the context value identifier for the CSRF token.
|
||||
// The token is only stored in the context if the RefreshToken function was called before.
|
||||
const ContextValueIdentifier = "_csrf_token"
|
||||
|
||||
// Middleware is a type that creates a new CSRF middleware. The CSRF middleware
|
||||
// can be used to mitigate Cross-Site Request Forgery attacks.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new CSRF middleware with the provided options.
|
||||
func New(sessionReader SessionReader, sessionWriter SessionWriter, opts ...Option) *Middleware {
|
||||
opts = append(opts, withSessionReader(sessionReader), withSessionWriter(sessionWriter))
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
checkForPRNG()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the CSRF middleware handler. This middleware validates the CSRF token and calls the specified
|
||||
// error handler if an invalid CSRF token was found.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if slices.Contains(m.o.ignoreMethods, r.Method) {
|
||||
next.ServeHTTP(w, r) // skip CSRF check for ignored methods
|
||||
return
|
||||
}
|
||||
|
||||
// get the token from the request
|
||||
token := m.o.tokenGetter(r)
|
||||
storedToken := m.o.sessionGetter(r)
|
||||
|
||||
if !tokenEqual(token, storedToken) {
|
||||
m.o.errCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken generates a new CSRF Token and stores it in the session. The token is also passed to subsequent handlers
|
||||
// via the context value ContextValueIdentifier.
|
||||
func (m *Middleware) RefreshToken(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if GetToken(r.Context()) != "" {
|
||||
// token already generated higher up in the chain
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// generate a new token
|
||||
token := generateToken(m.o.tokenLength)
|
||||
key := generateToken(m.o.tokenLength)
|
||||
|
||||
// mask the token
|
||||
maskedToken := maskToken(token, key)
|
||||
|
||||
// store the encoded token in the session
|
||||
encodedToken := encodeToken(maskedToken)
|
||||
m.o.sessionWriter(r, encodedToken)
|
||||
|
||||
// pass the token down the chain via the context
|
||||
r = r.WithContext(setToken(r.Context(), encodedToken))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// region token-access
|
||||
|
||||
// GetToken retrieves the CSRF token from the given context. Ensure that the RefreshToken function was called before,
|
||||
// otherwise, no token is populated in the context.
|
||||
func GetToken(ctx context.Context) string {
|
||||
token, ok := ctx.Value(ContextValueIdentifier).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// endregion token-access
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
func setToken(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, ContextValueIdentifier, token)
|
||||
}
|
||||
|
||||
// defaultTokenGetter is the default token getter function for the CSRF middleware.
|
||||
// It checks the request form values, URL query parameters, and headers for the CSRF token.
|
||||
// The order of precedence is:
|
||||
// 1. Header "X-CSRF-TOKEN"
|
||||
// 2. Header "X-XSRF-TOKEN"
|
||||
// 3. URL query parameter "_csrf"
|
||||
// 4. Form value "_csrf"
|
||||
func defaultTokenGetter(r *http.Request) string {
|
||||
if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.URL.Query().Get("_csrf"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
if t := r.FormValue("_csrf"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultErrorHandler is the default error handler function for the CSRF middleware.
|
||||
// It writes a 403 Forbidden response.
|
||||
func defaultErrorHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
251
internal/app/api/core/middleware/csrf/middleware_test.go
Normal file
251
internal/app/api/core/middleware/csrf/middleware_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
)
|
||||
|
||||
func TestMiddleware_Handler(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{"ValidToken", "POST", "stored-token", http.StatusOK},
|
||||
{"ValidToken2", "PUT", "stored-token", http.StatusOK},
|
||||
{"ValidToken3", "GET", "stored-token", http.StatusOK},
|
||||
{"InvalidToken", "POST", "invalid-token", http.StatusForbidden},
|
||||
{"IgnoredMethod", "GET", "", http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
req.Header.Set("X-CSRF-TOKEN", tt.token)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != tt.wantStatus {
|
||||
t.Errorf("Handler() status = %d, want %d", status, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken(t *testing.T) {
|
||||
sessionToken := ""
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := GetToken(r.Context())
|
||||
if token == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
|
||||
if sessionToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken_chained(t *testing.T) {
|
||||
sessionToken := ""
|
||||
tokenWrites := 0
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
tokenWrites++
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.RefreshToken(m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := GetToken(r.Context())
|
||||
if token == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
|
||||
if sessionToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in session")
|
||||
}
|
||||
|
||||
if tokenWrites != 1 {
|
||||
t.Errorf("RefreshToken() wrote token to session more than once: %d", tokenWrites)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_RefreshToken_Handler(t *testing.T) {
|
||||
sessionToken := ""
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
// simulate two requests: first one GET request with the RefreshToken handler, the next one is a PUT request with
|
||||
// the token from the first request added as X-CSRF-TOKEN header
|
||||
|
||||
// first request
|
||||
retrievedToken := ""
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
retrievedToken = GetToken(r.Context())
|
||||
if retrievedToken == "" {
|
||||
t.Errorf("RefreshToken() did not set token in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusAccepted {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusAccepted)
|
||||
}
|
||||
if retrievedToken == "" {
|
||||
t.Errorf("no token retrieved")
|
||||
}
|
||||
if retrievedToken != sessionToken {
|
||||
t.Errorf("token in context does not match token in session")
|
||||
}
|
||||
|
||||
// second request
|
||||
req = httptest.NewRequest("PUT", "/", nil)
|
||||
req.Header.Set("X-CSRF-TOKEN", retrievedToken)
|
||||
rr = httptest.NewRecorder()
|
||||
handler = m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_FormBody(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyData, err := request.BodyString(r)
|
||||
if err != nil {
|
||||
t.Errorf("Handler() error = %v, want nil", err)
|
||||
}
|
||||
// ensure that the body is empty - ParseForm() should have been called before by the CSRF middleware
|
||||
if bodyData != "" {
|
||||
t.Errorf("Handler() bodyData = %s, want empty", bodyData)
|
||||
}
|
||||
|
||||
if r.FormValue("_csrf") != "stored-token" {
|
||||
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Form = make(map[string][]string)
|
||||
req.Form.Add("_csrf", "stored-token")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_FormBodyAvailable(t *testing.T) {
|
||||
sessionToken := "stored-token"
|
||||
sessionReader := func(r *http.Request) string {
|
||||
return sessionToken
|
||||
}
|
||||
sessionWriter := func(r *http.Request, token string) {
|
||||
sessionToken = token
|
||||
}
|
||||
m := New(sessionReader, sessionWriter)
|
||||
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyData, err := request.BodyString(r)
|
||||
if err != nil {
|
||||
t.Errorf("Handler() error = %v, want nil", err)
|
||||
}
|
||||
// ensure that the body is not empty, as the CSRF middleware should not have read the body
|
||||
if bodyData != "the original body" {
|
||||
t.Errorf("Handler() bodyData = %s, want %s", bodyData, "the original body")
|
||||
}
|
||||
|
||||
// check if the token is available in the form values (from query parameters)
|
||||
if r.FormValue("_csrf") != "stored-token" {
|
||||
t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/?_csrf=stored-token", nil)
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
req.Body = io.NopCloser(strings.NewReader("the original body"))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("Handler() status = %d, want %d", status, http.StatusOK)
|
||||
}
|
||||
}
|
88
internal/app/api/core/middleware/csrf/options.go
Normal file
88
internal/app/api/core/middleware/csrf/options.go
Normal file
@ -0,0 +1,88 @@
|
||||
package csrf
|
||||
|
||||
import "net/http"
|
||||
|
||||
type SessionReader func(r *http.Request) string
|
||||
type SessionWriter func(r *http.Request, token string)
|
||||
|
||||
// options is a struct that contains options for the CSRF middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
tokenLength int
|
||||
ignoreMethods []string
|
||||
|
||||
errCallbackOverride bool
|
||||
errCallback func(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
tokenGetterOverride bool
|
||||
tokenGetter func(r *http.Request) string
|
||||
|
||||
sessionGetter SessionReader
|
||||
sessionWriter SessionWriter
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the CSRF middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithTokenLength is a method that sets the token length for the CSRF middleware.
|
||||
// The default value is 32.
|
||||
func WithTokenLength(length int) Option {
|
||||
return func(o *options) {
|
||||
o.tokenLength = length
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorCallback is a method that sets the error callback function for the CSRF middleware.
|
||||
// The error callback function is called when the CSRF token is invalid.
|
||||
// The default behavior is to write a 403 Forbidden response.
|
||||
func WithErrorCallback(fn func(w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.errCallback = fn
|
||||
o.errCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenGetter is a method that sets the token getter function for the CSRF middleware.
|
||||
// The token getter function is called to get the CSRF token from the request.
|
||||
// The default behavior is to get the token from the "X-CSRF-Token" header.
|
||||
func WithTokenGetter(fn func(r *http.Request) string) Option {
|
||||
return func(o *options) {
|
||||
o.tokenGetter = fn
|
||||
o.tokenGetterOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// withSessionReader is a method that sets the session reader function for the CSRF middleware.
|
||||
// The session reader function is called to get the CSRF token from the session.
|
||||
func withSessionReader(fn SessionReader) Option {
|
||||
return func(o *options) {
|
||||
o.sessionGetter = fn
|
||||
}
|
||||
}
|
||||
|
||||
// withSessionWriter is a method that sets the session writer function for the CSRF middleware.
|
||||
// The session writer function is called to write the CSRF token to the session.
|
||||
func withSessionWriter(fn SessionWriter) Option {
|
||||
return func(o *options) {
|
||||
o.sessionWriter = fn
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
tokenLength: 32,
|
||||
ignoreMethods: []string{"GET", "HEAD", "OPTIONS"},
|
||||
errCallbackOverride: false,
|
||||
errCallback: defaultErrorHandler,
|
||||
tokenGetterOverride: false,
|
||||
tokenGetter: defaultTokenGetter,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
75
internal/app/api/core/middleware/csrf/options_test.go
Normal file
75
internal/app/api/core/middleware/csrf/options_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithTokenLength(t *testing.T) {
|
||||
o := newOptions(WithTokenLength(64))
|
||||
if o.tokenLength != 64 {
|
||||
t.Errorf("WithTokenLength() = %d, want %d", o.tokenLength, 64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithErrorCallback(t *testing.T) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
}
|
||||
o := newOptions(WithErrorCallback(callback))
|
||||
if !o.errCallbackOverride {
|
||||
t.Errorf("WithErrorCallback() did not set errCallbackOverride to true")
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("WithErrorCallback() did not set errCallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTokenGetter(t *testing.T) {
|
||||
getter := func(r *http.Request) string {
|
||||
return "test-token"
|
||||
}
|
||||
o := newOptions(WithTokenGetter(getter))
|
||||
if !o.tokenGetterOverride {
|
||||
t.Errorf("WithTokenGetter() did not set tokenGetterOverride to true")
|
||||
}
|
||||
if o.tokenGetter == nil {
|
||||
t.Errorf("WithTokenGetter() did not set tokenGetter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSessionReader(t *testing.T) {
|
||||
reader := func(r *http.Request) string {
|
||||
return "session-token"
|
||||
}
|
||||
o := newOptions(withSessionReader(reader))
|
||||
if o.sessionGetter == nil {
|
||||
t.Errorf("withSessionReader() did not set sessionGetter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSessionWriter(t *testing.T) {
|
||||
writer := func(r *http.Request, token string) {
|
||||
// do nothing
|
||||
}
|
||||
o := newOptions(withSessionWriter(writer))
|
||||
if o.sessionWriter == nil {
|
||||
t.Errorf("withSessionWriter() did not set sessionWriter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOptionsDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
if o.tokenLength != 32 {
|
||||
t.Errorf("newOptions() default tokenLength = %d, want %d", o.tokenLength, 32)
|
||||
}
|
||||
if len(o.ignoreMethods) != 3 {
|
||||
t.Errorf("newOptions() default ignoreMethods length = %d, want %d", len(o.ignoreMethods), 3)
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("newOptions() default errCallback is nil")
|
||||
}
|
||||
if o.tokenGetter == nil {
|
||||
t.Errorf("newOptions() default tokenGetter is nil")
|
||||
}
|
||||
}
|
90
internal/app/api/core/middleware/csrf/token.go
Normal file
90
internal/app/api/core/middleware/csrf/token.go
Normal file
@ -0,0 +1,90 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// checkForPRNG is a function that checks if a cryptographically secure PRNG is available.
|
||||
// If it is not available, the function panics.
|
||||
func checkForPRNG() {
|
||||
buf := make([]byte, 1)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// generateToken is a function that generates a secure random CSRF token.
|
||||
func generateToken(length int) []byte {
|
||||
bytes := make([]byte, length)
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return bytes
|
||||
}
|
||||
|
||||
// encodeToken is a function that encodes a token to a base64 string.
|
||||
func encodeToken(token []byte) string {
|
||||
return base64.URLEncoding.EncodeToString(token)
|
||||
}
|
||||
|
||||
// decodeToken is a function that decodes a base64 string to a token.
|
||||
func decodeToken(token string) ([]byte, error) {
|
||||
return base64.URLEncoding.DecodeString(token)
|
||||
}
|
||||
|
||||
// maskToken is a function that masks a token with a given key.
|
||||
// The returned byte slice contains the key + the masked token.
|
||||
// The key needs to have the same length as the token, otherwise the function panics.
|
||||
// So the resulting slice has a length of len(token) * 2.
|
||||
func maskToken(token, key []byte) []byte {
|
||||
if len(token) != len(key) {
|
||||
panic("token and key must have the same length")
|
||||
}
|
||||
|
||||
// masked contains the key in the first half and the XOR masked token in the second half
|
||||
tokenLength := len(token)
|
||||
masked := make([]byte, tokenLength*2)
|
||||
for i := 0; i < len(token); i++ {
|
||||
masked[i] = key[i]
|
||||
masked[i+tokenLength] = token[i] ^ key[i] // XOR mask
|
||||
}
|
||||
|
||||
return masked
|
||||
}
|
||||
|
||||
// unmaskToken is a function that unmask a token which contains the key in the first half.
|
||||
// The returned byte slice contains the unmasked token, it has exactly half the length of the input slice.
|
||||
func unmaskToken(masked []byte) []byte {
|
||||
tokenLength := len(masked) / 2
|
||||
token := make([]byte, tokenLength)
|
||||
for i := 0; i < tokenLength; i++ {
|
||||
token[i] = masked[i] ^ masked[i+tokenLength] // XOR unmask
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// tokenEqual is a function that compares two tokens for equality.
|
||||
func tokenEqual(a, b string) bool {
|
||||
decodedA, err := decodeToken(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
decodedB, err := decodeToken(b)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
unmaskedA := unmaskToken(decodedA)
|
||||
unmaskedB := unmaskToken(decodedB)
|
||||
|
||||
return slices.Equal(unmaskedA, unmaskedB)
|
||||
}
|
81
internal/app/api/core/middleware/csrf/token_test.go
Normal file
81
internal/app/api/core/middleware/csrf/token_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckForPRNG(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("checkForPRNG() panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
checkForPRNG()
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
length := 32
|
||||
token := generateToken(length)
|
||||
if len(token) != length {
|
||||
t.Errorf("generateToken() returned token of length %d, expected %d", len(token), length)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
encoded := encodeToken(token)
|
||||
expected := base64.URLEncoding.EncodeToString(token)
|
||||
if encoded != expected {
|
||||
t.Errorf("encodeToken() = %v, want %v", encoded, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken(t *testing.T) {
|
||||
token := "dGVzdHRva2Vu"
|
||||
expected := []byte("testtoken")
|
||||
decoded, err := decodeToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("decodeToken() error = %v", err)
|
||||
}
|
||||
if string(decoded) != string(expected) {
|
||||
t.Errorf("decodeToken() = %v, want %v", decoded, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
key := []byte("keykeykey")
|
||||
masked := maskToken(token, key)
|
||||
if len(masked) != len(token)*2 {
|
||||
t.Errorf("maskToken() returned masked token of length %d, expected %d", len(masked), len(token)*2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmaskToken(t *testing.T) {
|
||||
token := []byte("testtoken")
|
||||
key := []byte("keykeykey")
|
||||
masked := maskToken(token, key)
|
||||
unmasked := unmaskToken(masked)
|
||||
if string(unmasked) != string(token) {
|
||||
t.Errorf("unmaskToken() = %v, want %v", unmasked, token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenEqual(t *testing.T) {
|
||||
tokenA := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}))
|
||||
tokenB := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
|
||||
if !tokenEqual(tokenA, tokenB) {
|
||||
t.Errorf("tokenEqual() = false, want true")
|
||||
}
|
||||
|
||||
tokenC := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x07, 0x08, 0x09}))
|
||||
if !tokenEqual(tokenA, tokenC) {
|
||||
t.Errorf("tokenEqual() = false, want true")
|
||||
}
|
||||
|
||||
tokenD := encodeToken(maskToken([]byte{0x09, 0x02, 0x03}, []byte{0x04, 0x05, 0x06}))
|
||||
if tokenEqual(tokenA, tokenD) {
|
||||
t.Errorf("tokenEqual() = true, want false")
|
||||
}
|
||||
}
|
199
internal/app/api/core/middleware/logging/middleware.go
Normal file
199
internal/app/api/core/middleware/logging/middleware.go
Normal file
@ -0,0 +1,199 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogLevel is an enumeration of the different log levels.
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LogLevelDebug LogLevel = iota
|
||||
LogLevelInfo
|
||||
LogLevelWarn
|
||||
LogLevelError
|
||||
)
|
||||
|
||||
// Logger is an interface that defines the methods that a logger must implement.
|
||||
// This allows the logging middleware to be used with different logging libraries.
|
||||
type Logger interface {
|
||||
// Debugf logs a message at debug level.
|
||||
Debugf(format string, args ...any)
|
||||
// Infof logs a message at info level.
|
||||
Infof(format string, args ...any)
|
||||
// Warnf logs a message at warn level.
|
||||
Warnf(format string, args ...any)
|
||||
// Errorf logs a message at error level.
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// Middleware is a type that creates a new logging middleware. The logging middleware
|
||||
// logs information about each request.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new logging middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the logging middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := newWriterWrapper(w)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
info := m.extractInfoMap(r, start, ww)
|
||||
|
||||
if m.o.logger == nil {
|
||||
msg, args := m.buildSlogMessageAndArguments(info)
|
||||
m.logMsg(msg, args...)
|
||||
} else {
|
||||
msg := m.buildNormalLogMessage(info)
|
||||
m.logMsg(msg)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) extractInfoMap(r *http.Request, start time.Time, ww *writerWrapper) map[string]any {
|
||||
info := make(map[string]any)
|
||||
|
||||
info["method"] = r.Method
|
||||
info["path"] = r.URL.Path
|
||||
info["protocol"] = r.Proto
|
||||
info["clientIP"] = r.Header.Get("X-Forwarded-For")
|
||||
if info["clientIP"] == "" {
|
||||
// If the X-Forwarded-For header is not set, use the remote address without the port number.
|
||||
lastColonIndex := strings.LastIndex(r.RemoteAddr, ":")
|
||||
switch lastColonIndex {
|
||||
case -1:
|
||||
info["clientIP"] = r.RemoteAddr
|
||||
default:
|
||||
info["clientIP"] = r.RemoteAddr[:lastColonIndex]
|
||||
}
|
||||
}
|
||||
info["userAgent"] = r.UserAgent()
|
||||
info["referer"] = r.Header.Get("Referer")
|
||||
info["duration"] = time.Since(start).String()
|
||||
info["status"] = ww.StatusCode
|
||||
info["dataLength"] = ww.WrittenBytes
|
||||
|
||||
if m.o.headerRequestIdKey != "" {
|
||||
info["headerRequestId"] = r.Header.Get(m.o.headerRequestIdKey)
|
||||
}
|
||||
if m.o.contextRequestIdKey != "" {
|
||||
info["contextRequestId"], _ = r.Context().Value(m.o.contextRequestIdKey).(string)
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func (m *Middleware) buildNormalLogMessage(info map[string]any) string {
|
||||
switch {
|
||||
case info["headerRequestId"] != nil && info["contextRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s ctx=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["headerRequestId"], info["contextRequestId"])
|
||||
case info["headerRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["headerRequestId"])
|
||||
case info["contextRequestId"] != nil:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - ctx=%s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"],
|
||||
info["contextRequestId"])
|
||||
default:
|
||||
return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s",
|
||||
info["method"], info["path"], info["protocol"],
|
||||
info["status"], info["dataLength"],
|
||||
info["duration"],
|
||||
info["clientIP"], info["userAgent"], info["referer"])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) buildSlogMessageAndArguments(info map[string]any) (message string, args []any) {
|
||||
message = fmt.Sprintf("%s %s", info["method"], info["path"])
|
||||
|
||||
// Use a fixed order for the keys, so that the message is always the same.
|
||||
// Skip method and path as they are already in the message.
|
||||
keys := []string{
|
||||
"protocol",
|
||||
"status",
|
||||
"dataLength",
|
||||
"duration",
|
||||
"clientIP",
|
||||
"userAgent",
|
||||
"referer",
|
||||
"headerRequestId",
|
||||
"contextRequestId",
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := info[k]; ok {
|
||||
args = append(args, k, v) // only add key, value if it exists
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Middleware) addPrefix(message string) string {
|
||||
if m.o.prefix != "" {
|
||||
return m.o.prefix + " " + message
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (m *Middleware) logMsg(message string, args ...any) {
|
||||
message = m.addPrefix(message)
|
||||
|
||||
if m.o.logger != nil {
|
||||
switch m.o.logLevel {
|
||||
case LogLevelDebug:
|
||||
m.o.logger.Debugf(message, args...)
|
||||
case LogLevelInfo:
|
||||
m.o.logger.Infof(message, args...)
|
||||
case LogLevelWarn:
|
||||
m.o.logger.Warnf(message, args...)
|
||||
case LogLevelError:
|
||||
m.o.logger.Errorf(message, args...)
|
||||
default:
|
||||
m.o.logger.Infof(message, args...)
|
||||
}
|
||||
} else {
|
||||
switch m.o.logLevel {
|
||||
case LogLevelDebug:
|
||||
slog.Debug(message, args...)
|
||||
case LogLevelInfo:
|
||||
slog.Info(message, args...)
|
||||
case LogLevelWarn:
|
||||
slog.Warn(message, args...)
|
||||
case LogLevelError:
|
||||
slog.Error(message, args...)
|
||||
default:
|
||||
slog.Info(message, args...)
|
||||
}
|
||||
}
|
||||
}
|
148
internal/app/api/core/middleware/logging/middleware_test.go
Normal file
148
internal/app/api/core/middleware/logging/middleware_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockLogger struct {
|
||||
messages []string
|
||||
}
|
||||
|
||||
func (m *mockLogger) Debugf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "DEBUG: "+format)
|
||||
}
|
||||
func (m *mockLogger) Infof(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "INFO: "+format)
|
||||
}
|
||||
func (m *mockLogger) Warnf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "WARN: "+format)
|
||||
}
|
||||
func (m *mockLogger) Errorf(format string, _ ...any) {
|
||||
m.messages = append(m.messages, "ERROR: "+format)
|
||||
}
|
||||
|
||||
func TestMiddleware_Normal(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
|
||||
}
|
||||
|
||||
expected := "Hello, World!"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
|
||||
}
|
||||
|
||||
if len(logger.messages) == 0 {
|
||||
t.Errorf("expected log messages, got none")
|
||||
}
|
||||
|
||||
if len(logger.messages) != 0 && !strings.Contains(logger.messages[0], "ERROR: GET /foo") {
|
||||
t.Errorf("expected log message to contain request info, got %v", logger.messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Extended(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithContextRequestIdKey("requestId"), WithHeaderRequestIdKey("X-Request-Id")).
|
||||
Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status)
|
||||
}
|
||||
|
||||
expected := "Hello, World!"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddr(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "xhamster.com:1234"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrNoPort(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "xhamster.com"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrV6(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "[::1]:4711"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
||||
|
||||
func TestMiddleware_Logger_remoteAddrV6NoPort(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.RemoteAddr = "[::1]"
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
}
|
80
internal/app/api/core/middleware/logging/options.go
Normal file
80
internal/app/api/core/middleware/logging/options.go
Normal file
@ -0,0 +1,80 @@
|
||||
package logging
|
||||
|
||||
// options is a struct that contains options for the logging middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
logLevel LogLevel
|
||||
logger Logger
|
||||
prefix string
|
||||
|
||||
contextRequestIdKey string
|
||||
headerRequestIdKey string
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the logging middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithLevel is a method that sets the log level for the logging middleware.
|
||||
// Possible values are LogLevelDebug, LogLevelInfo, LogLevelWarn, and LogLevelError.
|
||||
// The default value is LogLevelInfo.
|
||||
func WithLevel(level LogLevel) Option {
|
||||
return func(o *options) {
|
||||
o.logLevel = level
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrefix is a method that sets the prefix for the logging middleware.
|
||||
// If a prefix is set, it will be prepended to each log message. A space will
|
||||
// be added between the prefix and the log message.
|
||||
// The default value is an empty string.
|
||||
func WithPrefix(prefix string) Option {
|
||||
return func(o *options) {
|
||||
o.prefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithContextRequestIdKey is a method that sets the key for the request ID in the
|
||||
// request context. If a key is set, the logging middleware will use this key to
|
||||
// retrieve the request ID from the request context.
|
||||
// The default value is an empty string, meaning the request ID will not be logged.
|
||||
func WithContextRequestIdKey(key string) Option {
|
||||
return func(o *options) {
|
||||
o.contextRequestIdKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaderRequestIdKey is a method that sets the key for the request ID in the
|
||||
// request headers. If a key is set, the logging middleware will use this key to
|
||||
// retrieve the request ID from the request headers.
|
||||
// The default value is an empty string, meaning the request ID will not be logged.
|
||||
func WithHeaderRequestIdKey(key string) Option {
|
||||
return func(o *options) {
|
||||
o.headerRequestIdKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger is a method that sets the logger for the logging middleware.
|
||||
// If a logger is set, the logging middleware will use this logger to log messages.
|
||||
// The default logger is the structured slog logger.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(o *options) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
logLevel: LogLevelInfo,
|
||||
logger: nil,
|
||||
prefix: "",
|
||||
contextRequestIdKey: "",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
88
internal/app/api/core/middleware/logging/options_test.go
Normal file
88
internal/app/api/core/middleware/logging/options_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithLevel(t *testing.T) {
|
||||
// table test to check all possible log levels
|
||||
levels := []LogLevel{
|
||||
LogLevelDebug,
|
||||
LogLevelInfo,
|
||||
LogLevelWarn,
|
||||
LogLevelError,
|
||||
}
|
||||
|
||||
for _, level := range levels {
|
||||
opt := WithLevel(level)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logLevel != level {
|
||||
t.Errorf("expected log level to be %v, got %v", level, o.logLevel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithPrefix(t *testing.T) {
|
||||
prefix := "TEST"
|
||||
opt := WithPrefix(prefix)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.prefix != prefix {
|
||||
t.Errorf("expected prefix to be %v, got %v", prefix, o.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithContextRequestIdKey(t *testing.T) {
|
||||
key := "contextKey"
|
||||
opt := WithContextRequestIdKey(key)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.contextRequestIdKey != key {
|
||||
t.Errorf("expected contextRequestIdKey to be %v, got %v", key, o.contextRequestIdKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithHeaderRequestIdKey(t *testing.T) {
|
||||
key := "headerKey"
|
||||
opt := WithHeaderRequestIdKey(key)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.headerRequestIdKey != key {
|
||||
t.Errorf("expected headerRequestIdKey to be %v, got %v", key, o.headerRequestIdKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogger(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
opt := WithLogger(logger)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logger != logger {
|
||||
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.logLevel != LogLevelInfo {
|
||||
t.Errorf("expected log level to be %v, got %v", LogLevelInfo, o.logLevel)
|
||||
}
|
||||
|
||||
if o.logger != nil {
|
||||
t.Errorf("expected logger to be nil, got %v", o.logger)
|
||||
}
|
||||
|
||||
if o.prefix != "" {
|
||||
t.Errorf("expected prefix to be empty, got %v", o.prefix)
|
||||
}
|
||||
|
||||
if o.contextRequestIdKey != "" {
|
||||
t.Errorf("expected contextRequestIdKey to be empty, got %v", o.contextRequestIdKey)
|
||||
}
|
||||
|
||||
if o.headerRequestIdKey != "" {
|
||||
t.Errorf("expected headerRequestIdKey to be empty, got %v", o.headerRequestIdKey)
|
||||
}
|
||||
}
|
45
internal/app/api/core/middleware/logging/writer.go
Normal file
45
internal/app/api/core/middleware/logging/writer.go
Normal file
@ -0,0 +1,45 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// writerWrapper wraps a http.ResponseWriter and tracks the number of bytes written to it.
|
||||
// It also tracks the http response code passed to the WriteHeader func of
|
||||
// the ResponseWriter.
|
||||
type writerWrapper struct {
|
||||
http.ResponseWriter
|
||||
|
||||
// StatusCode is the last http response code passed to the WriteHeader func of
|
||||
// the ResponseWriter. If no such call is made, a default code of http.StatusOK
|
||||
// is assumed instead.
|
||||
StatusCode int
|
||||
|
||||
// WrittenBytes is the number of bytes successfully written by the Write or
|
||||
// ReadFrom function of the ResponseWriter. ResponseWriters may also write
|
||||
// data to their underlaying connection directly (e.g. headers), but those
|
||||
// are not tracked. Therefor the number of Written bytes will usually match
|
||||
// the size of the response body.
|
||||
WrittenBytes int64
|
||||
}
|
||||
|
||||
// WriteHeader wraps the WriteHeader method of the ResponseWriter and tracks the
|
||||
// http response code passed to it.
|
||||
func (w *writerWrapper) WriteHeader(code int) {
|
||||
w.StatusCode = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Write wraps the Write method of the ResponseWriter and tracks the number of bytes
|
||||
// written to it.
|
||||
func (w *writerWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
w.WrittenBytes += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter.
|
||||
// It initializes the StatusCode to http.StatusOK.
|
||||
func newWriterWrapper(w http.ResponseWriter) *writerWrapper {
|
||||
return &writerWrapper{ResponseWriter: w, StatusCode: http.StatusOK}
|
||||
}
|
85
internal/app/api/core/middleware/logging/writer_test.go
Normal file
85
internal/app/api/core/middleware/logging/writer_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriterWrapper_WriteHeader(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
ww.WriteHeader(http.StatusNotFound)
|
||||
|
||||
if ww.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusNotFound, ww.StatusCode)
|
||||
}
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Errorf("expected recorder status code to be %v, got %v", http.StatusNotFound, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterWrapper_Write(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
n, err := ww.Write(data)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
|
||||
}
|
||||
if ww.WrittenBytes != int64(len(data)) {
|
||||
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
|
||||
}
|
||||
if rr.Body.String() != string(data) {
|
||||
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterWrapper_WriteWithHeaders(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
n, err := ww.Write(data)
|
||||
|
||||
ww.Header().Set("Content-Type", "text/plain")
|
||||
ww.Header().Set("X-Some-Header", "some-value")
|
||||
ww.WriteHeader(http.StatusTeapot)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected written bytes to be %v, got %v", len(data), n)
|
||||
}
|
||||
if ww.WrittenBytes != int64(len(data)) {
|
||||
t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes)
|
||||
}
|
||||
if rr.Body.String() != string(data) {
|
||||
t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String())
|
||||
}
|
||||
if ww.StatusCode != http.StatusTeapot {
|
||||
t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, ww.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWriterWrapper(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ww := newWriterWrapper(rr)
|
||||
|
||||
if ww.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected initial status code to be %v, got %v", http.StatusOK, ww.StatusCode)
|
||||
}
|
||||
if ww.WrittenBytes != 0 {
|
||||
t.Errorf("expected initial WrittenBytes to be %v, got %v", 0, ww.WrittenBytes)
|
||||
}
|
||||
if ww.ResponseWriter != rr {
|
||||
t.Errorf("expected ResponseWriter to be %v, got %v", rr, ww.ResponseWriter)
|
||||
}
|
||||
}
|
133
internal/app/api/core/middleware/recovery/middleware.go
Normal file
133
internal/app/api/core/middleware/recovery/middleware.go
Normal file
@ -0,0 +1,133 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Logger is an interface that defines the methods that a logger must implement.
|
||||
// This allows the logging middleware to be used with different logging libraries.
|
||||
type Logger interface {
|
||||
// Errorf logs a message at error level.
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// Middleware is a type that creates a new recovery middleware. The recovery middleware
|
||||
// recovers from panics and returns an Internal Server Error response. This middleware should
|
||||
// be the first middleware in the middleware chain, so that it can recover from panics in other
|
||||
// middlewares.
|
||||
type Middleware struct {
|
||||
o options
|
||||
}
|
||||
|
||||
// New returns a new recovery middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the recovery middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
stack := debug.Stack()
|
||||
|
||||
realErr, ok := err.(error)
|
||||
if !ok {
|
||||
realErr = fmt.Errorf("%v", err)
|
||||
}
|
||||
|
||||
// Check for a broken connection, as it is not really a
|
||||
// condition that warrants a panic stack trace.
|
||||
brokenPipe := isBrokenPipeError(realErr)
|
||||
|
||||
// Log the error and stack trace
|
||||
if m.o.logCallback != nil {
|
||||
m.o.logCallback(realErr, stack, brokenPipe)
|
||||
}
|
||||
|
||||
switch {
|
||||
case brokenPipe && m.o.brokenPipeCallback != nil:
|
||||
m.o.brokenPipeCallback(realErr, stack, w, r)
|
||||
case !brokenPipe && m.o.errCallback != nil:
|
||||
m.o.errCallback(realErr, stack, w, r)
|
||||
default:
|
||||
// no callback set, simply recover and do nothing...
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func addPrefix(o options, message string) string {
|
||||
if o.defaultLogPrefix != "" {
|
||||
return o.defaultLogPrefix + " " + message
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
// defaultErrCallback is the default error callback function for the recovery middleware.
|
||||
// It writes a JSON response with an Internal Server Error status code. If the exposeStackTrace option is
|
||||
// enabled, the stack trace is included in the response.
|
||||
func getDefaultErrCallback(o options) func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
return func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
responseBody := map[string]interface{}{
|
||||
"error": "Internal Server Error",
|
||||
}
|
||||
if o.exposeStackTrace && len(stack) > 0 {
|
||||
responseBody["stack"] = string(stack)
|
||||
}
|
||||
|
||||
jsonBody, _ := json.Marshal(responseBody)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write(jsonBody)
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultLogCallback is the default log callback function for the recovery middleware.
|
||||
// It logs the error and stack trace using the structured slog logger or the provided logger in Error level.
|
||||
func getDefaultLogCallback(o options) func(error, []byte, bool) {
|
||||
return func(err error, stack []byte, brokenPipe bool) {
|
||||
if brokenPipe {
|
||||
return // by default, ignore broken pipe errors
|
||||
}
|
||||
|
||||
switch {
|
||||
case o.useSlog:
|
||||
slog.Error(addPrefix(o, err.Error()), "stack", string(stack))
|
||||
case o.logger != nil:
|
||||
o.logger.Errorf(fmt.Sprintf("%s; stacktrace=%s", addPrefix(o, err.Error()), string(stack)))
|
||||
default:
|
||||
// no logger set, do nothing...
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isBrokenPipeError(err error) bool {
|
||||
var syscallErr *os.SyscallError
|
||||
if errors.As(err, &syscallErr) {
|
||||
errMsg := strings.ToLower(syscallErr.Err.Error())
|
||||
if strings.Contains(errMsg, "broken pipe") ||
|
||||
strings.Contains(errMsg, "connection reset by peer") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
149
internal/app/api/core/middleware/recovery/middleware_test.go
Normal file
149
internal/app/api/core/middleware/recovery/middleware_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockLogger struct{}
|
||||
|
||||
func (m *mockLogger) Errorf(_ string, _ ...any) {}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options []Option
|
||||
panicSimulator func()
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectStack bool
|
||||
}{
|
||||
{
|
||||
name: "default behavior",
|
||||
options: []Option{},
|
||||
panicSimulator: func() {
|
||||
panic(errors.New("test panic"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: `{"error":"Internal Server Error"}`,
|
||||
},
|
||||
{
|
||||
name: "custom error callback",
|
||||
options: []Option{
|
||||
WithErrCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
w.Write([]byte("custom error"))
|
||||
}),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic(errors.New("test panic"))
|
||||
},
|
||||
expectedStatus: http.StatusTeapot,
|
||||
expectedBody: "custom error",
|
||||
},
|
||||
{
|
||||
name: "broken pipe error",
|
||||
options: []Option{
|
||||
WithBrokenPipeCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("broken pipe"))
|
||||
}),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic(&os.SyscallError{Err: errors.New("broken pipe")})
|
||||
},
|
||||
expectedStatus: http.StatusServiceUnavailable,
|
||||
expectedBody: "broken pipe",
|
||||
},
|
||||
{
|
||||
name: "default callback broken pipe error",
|
||||
options: nil,
|
||||
panicSimulator: func() {
|
||||
panic(&os.SyscallError{Err: errors.New("broken pipe")})
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "",
|
||||
},
|
||||
{
|
||||
name: "default callback normal error",
|
||||
options: nil,
|
||||
panicSimulator: func() {
|
||||
panic("something went wrong")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "{\"error\":\"Internal Server Error\"}",
|
||||
},
|
||||
{
|
||||
name: "default callback with stack trace",
|
||||
options: []Option{
|
||||
WithExposeStackTrace(true),
|
||||
},
|
||||
panicSimulator: func() {
|
||||
panic("something went wrong")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "\"stack\":",
|
||||
expectStack: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := New(tt.options...).Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tt.panicSimulator()
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %v, got %v", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
if !tt.expectStack && rr.Body.String() != tt.expectedBody {
|
||||
t.Errorf("expected body %v, got %v", tt.expectedBody, rr.Body.String())
|
||||
}
|
||||
if tt.expectStack && !strings.Contains(rr.Body.String(), tt.expectedBody) {
|
||||
t.Errorf("expected body to contain %v, got %v", tt.expectedBody, rr.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBrokenPipeError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "broken pipe error",
|
||||
err: &os.SyscallError{Err: errors.New("broken pipe")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "connection reset by peer error",
|
||||
err: &os.SyscallError{Err: errors.New("connection reset by peer")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isBrokenPipeError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
129
internal/app/api/core/middleware/recovery/options.go
Normal file
129
internal/app/api/core/middleware/recovery/options.go
Normal file
@ -0,0 +1,129 @@
|
||||
package recovery
|
||||
|
||||
import "net/http"
|
||||
|
||||
// options is a struct that contains options for the recovery middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
logger Logger
|
||||
useSlog bool
|
||||
|
||||
errCallbackOverride bool
|
||||
errCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
|
||||
brokenPipeCallbackOverride bool
|
||||
brokenPipeCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
exposeStackTrace bool
|
||||
defaultLogPrefix string
|
||||
logCallbackOverride bool
|
||||
logCallback func(err error, stack []byte, brokenPipe bool)
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the recovery middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithErrCallback sets the error callback function for the recovery middleware.
|
||||
// The error callback function is called when a panic is recovered by the middleware.
|
||||
// This function completely overrides the default behavior of the middleware. It is the
|
||||
// responsibility of the user to handle the error and write a response to the client.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithErrCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.errCallback = fn
|
||||
o.errCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithBrokenPipeCallback sets the broken pipe callback function for the recovery middleware.
|
||||
// The broken pipe callback function is called when a broken pipe error is recovered by the middleware.
|
||||
// This function completely overrides the default behavior of the middleware. It is the responsibility
|
||||
// of the user to handle the error and write a response to the client.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithBrokenPipeCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option {
|
||||
return func(o *options) {
|
||||
o.brokenPipeCallback = fn
|
||||
o.brokenPipeCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogCallback sets the log callback function for the recovery middleware.
|
||||
// The log callback function is called when a panic is recovered by the middleware.
|
||||
// This function allows the user to log the error and stack trace. The default behavior is to log
|
||||
// the error and stack trace in Error level.
|
||||
// This function completely overrides the default behavior of the middleware.
|
||||
//
|
||||
// Ensure that this function does not panic, as it will be called in a deferred function!
|
||||
func WithLogCallback(fn func(err error, stack []byte, brokenPipe bool)) Option {
|
||||
return func(o *options) {
|
||||
o.logCallback = fn
|
||||
o.logCallbackOverride = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger is a method that sets the logger for the logging middleware.
|
||||
// If a logger is set, the logging middleware will use this logger to log messages.
|
||||
// The default logger is the structured slog logger, see WithSlog.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(o *options) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithSlog is a method that sets whether the recovery middleware should use the structured slog logger.
|
||||
// If set to true, the middleware will use the structured slog logger. If set to false, the middleware
|
||||
// will not use any logger unless one is explicitly set with the WithLogger option.
|
||||
// The default value is true.
|
||||
func WithSlog(useSlog bool) Option {
|
||||
return func(o *options) {
|
||||
o.useSlog = useSlog
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultLogPrefix is a method that sets the default log prefix for the recovery middleware.
|
||||
// If a default log prefix is set and the default log callback is used, the prefix will be prepended
|
||||
// to each log message. A space will be added between the prefix and the log message.
|
||||
// The default value is an empty string.
|
||||
func WithDefaultLogPrefix(defaultLogPrefix string) Option {
|
||||
return func(o *options) {
|
||||
o.defaultLogPrefix = defaultLogPrefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeStackTrace is a method that sets whether the stack trace should be exposed in the response.
|
||||
// If set to true, the stack trace will be included in the response body. If set to false, the stack trace
|
||||
// will not be included in the response body. This only applies to the default error callback.
|
||||
// The default value is false.
|
||||
func WithExposeStackTrace(exposeStackTrace bool) Option {
|
||||
return func(o *options) {
|
||||
o.exposeStackTrace = exposeStackTrace
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
logger: nil,
|
||||
useSlog: true,
|
||||
errCallback: nil,
|
||||
brokenPipeCallback: nil, // by default, ignore broken pipe errors
|
||||
exposeStackTrace: false,
|
||||
defaultLogPrefix: "",
|
||||
logCallback: nil,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
if o.errCallback == nil && !o.errCallbackOverride {
|
||||
o.errCallback = getDefaultErrCallback(o)
|
||||
}
|
||||
if o.logCallback == nil && !o.logCallbackOverride {
|
||||
o.logCallback = getDefaultLogCallback(o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
100
internal/app/api/core/middleware/recovery/options_test.go
Normal file
100
internal/app/api/core/middleware/recovery/options_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithErrCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
|
||||
opt := WithErrCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("expected errCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithBrokenPipeCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {}
|
||||
opt := WithBrokenPipeCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.brokenPipeCallback == nil {
|
||||
t.Errorf("expected brokenPipeCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogCallback(t *testing.T) {
|
||||
callback := func(err error, stack []byte, brokenPipe bool) {}
|
||||
opt := WithLogCallback(callback)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logCallback == nil {
|
||||
t.Errorf("expected logCallback to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogger(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
opt := WithLogger(logger)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.logger != logger {
|
||||
t.Errorf("expected logger to be %v, got %v", logger, o.logger)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSlog(t *testing.T) {
|
||||
opt := WithSlog(false)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.useSlog != false {
|
||||
t.Errorf("expected useSlog to be false, got %v", o.useSlog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithDefaultLogPrefix(t *testing.T) {
|
||||
prefix := "PREFIX"
|
||||
opt := WithDefaultLogPrefix(prefix)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.defaultLogPrefix != prefix {
|
||||
t.Errorf("expected defaultLogPrefix to be %v, got %v", prefix, o.defaultLogPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithExposeStackTrace(t *testing.T) {
|
||||
opt := WithExposeStackTrace(true)
|
||||
o := newOptions(opt)
|
||||
|
||||
if o.exposeStackTrace != true {
|
||||
t.Errorf("expected exposeStackTrace to be true, got %v", o.exposeStackTrace)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOptionsDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.logger != nil {
|
||||
t.Errorf("expected logger to be nil, got %v", o.logger)
|
||||
}
|
||||
if o.useSlog != true {
|
||||
t.Errorf("expected useSlog to be true, got %v", o.useSlog)
|
||||
}
|
||||
if o.errCallback == nil {
|
||||
t.Errorf("expected errCallback to be set, got nil")
|
||||
}
|
||||
if o.brokenPipeCallback != nil {
|
||||
t.Errorf("expected brokenPipeCallback to be nil, got %T", o.brokenPipeCallback)
|
||||
}
|
||||
if o.exposeStackTrace != false {
|
||||
t.Errorf("expected exposeStackTrace to be false, got %T", o.exposeStackTrace)
|
||||
}
|
||||
if o.defaultLogPrefix != "" {
|
||||
t.Errorf("expected defaultLogPrefix to be empty, got %T", o.defaultLogPrefix)
|
||||
}
|
||||
if o.logCallback == nil {
|
||||
t.Errorf("expected logCallback to be set, got nil")
|
||||
}
|
||||
}
|
69
internal/app/api/core/middleware/tracing/middleware.go
Normal file
69
internal/app/api/core/middleware/tracing/middleware.go
Normal file
@ -0,0 +1,69 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Middleware is a type that creates a new tracing middleware. The tracing middleware
|
||||
// can be used to trace requests based on a request ID header or parameter.
|
||||
type Middleware struct {
|
||||
o options
|
||||
|
||||
seededRand *rand.Rand
|
||||
}
|
||||
|
||||
// New returns a new CORS middleware with the provided options.
|
||||
func New(opts ...Option) *Middleware {
|
||||
o := newOptions(opts...)
|
||||
|
||||
m := &Middleware{
|
||||
o: o,
|
||||
seededRand: rand.New(rand.NewSource(o.generateSeed)),
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler returns the tracing middleware handler.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var reqId string
|
||||
|
||||
// read upstream header und re-use it
|
||||
if m.o.upstreamReqIdHeader != "" {
|
||||
reqId = r.Header.Get(m.o.upstreamReqIdHeader)
|
||||
}
|
||||
|
||||
// generate new id
|
||||
if reqId == "" && m.o.generateLength > 0 {
|
||||
reqId = m.generateRandomId()
|
||||
}
|
||||
|
||||
// set response header
|
||||
if m.o.headerIdentifier != "" {
|
||||
w.Header().Set(m.o.headerIdentifier, reqId)
|
||||
}
|
||||
|
||||
// set context value
|
||||
if m.o.contextIdentifier != "" {
|
||||
ctx := context.WithValue(r.Context(), m.o.contextIdentifier, reqId)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r) // execute the next handler
|
||||
})
|
||||
}
|
||||
|
||||
// region internal-helpers
|
||||
|
||||
func (m *Middleware) generateRandomId() string {
|
||||
b := make([]byte, m.o.generateLength)
|
||||
for i := range b {
|
||||
b[i] = m.o.generateCharset[m.seededRand.Intn(len(m.o.generateCharset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// endregion internal-helpers
|
118
internal/app/api/core/middleware/tracing/middleware_test.go
Normal file
118
internal/app/api/core/middleware/tracing/middleware_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const defaultLength = 8
|
||||
const upstreamHeaderValue = "upstream-id"
|
||||
|
||||
func TestMiddleware_Handler_WithUpstreamHeader(t *testing.T) {
|
||||
m := New(WithUpstreamHeader("X-Upstream-Id"))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Header.Get("X-Upstream-Id")
|
||||
if reqId != upstreamHeaderValue {
|
||||
t.Errorf("expected upstream request id to be 'upstream-id', got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Upstream-Id", upstreamHeaderValue)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-Id") != upstreamHeaderValue {
|
||||
t.Errorf("expected X-Request-Id header to be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_GenerateNewId(t *testing.T) {
|
||||
idLen := 18
|
||||
m := New(WithIdLength(idLen))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if len(reqId) != 18 {
|
||||
t.Errorf("expected generated request id length to be %d, got %d", idLen, len(reqId))
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-Id") == "" || len(rr.Header().Get("X-Request-Id")) != idLen {
|
||||
t.Errorf("expected X-Request-Id header to be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_SetContextValue(t *testing.T) {
|
||||
m := New()
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("RequestId").(string)
|
||||
if reqId == "" || len(reqId) != defaultLength {
|
||||
t.Errorf("expected context request id to be set, got empty string")
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_SetCustomContextValue(t *testing.T) {
|
||||
m := New(WithContextIdentifier("Custom-Id"))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("Custom-Id").(string)
|
||||
if reqId == "" || len(reqId) != defaultLength {
|
||||
t.Errorf("expected context request id to be set, got empty string")
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdGenerated(t *testing.T) {
|
||||
m := New(WithIdLength(0))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if reqId != "" {
|
||||
t.Errorf("expected no request id to be generated, got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdHeaderSet(t *testing.T) {
|
||||
m := New(WithHeaderIdentifier(""))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := w.Header().Get("X-Request-Id")
|
||||
if reqId != "" {
|
||||
t.Errorf("expected no request id to be generated, got %s", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
func TestMiddleware_Handler_NoIdContextSet(t *testing.T) {
|
||||
m := New(WithHeaderIdentifier(""))
|
||||
handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqId := r.Context().Value("Request-Id")
|
||||
if reqId != nil {
|
||||
t.Errorf("expected no context request id to be set, got %v", reqId)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
85
internal/app/api/core/middleware/tracing/options.go
Normal file
85
internal/app/api/core/middleware/tracing/options.go
Normal file
@ -0,0 +1,85 @@
|
||||
package tracing
|
||||
|
||||
import "time"
|
||||
|
||||
// options is a struct that contains options for the tracing middleware.
|
||||
// It uses the functional options pattern for flexible configuration.
|
||||
type options struct {
|
||||
upstreamReqIdHeader string
|
||||
headerIdentifier string
|
||||
contextIdentifier string
|
||||
generateLength int
|
||||
generateCharset string
|
||||
generateSeed int64
|
||||
}
|
||||
|
||||
// Option is a type that is used to set options for the tracing middleware.
|
||||
// It implements the functional options pattern.
|
||||
type Option func(*options)
|
||||
|
||||
// WithIdSeed sets the seed for the random request id.
|
||||
// If no seed is provided, the current timestamp is used.
|
||||
func WithIdSeed(seed int64) Option {
|
||||
return func(o *options) {
|
||||
o.generateSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdCharset sets the charset that is used to generate a random request id.
|
||||
// By default, upper-case letters and numbers are used.
|
||||
func WithIdCharset(charset string) Option {
|
||||
return func(o *options) {
|
||||
o.generateCharset = charset
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdLength specifies the length of generated random ids.
|
||||
// By default, a length of 8 is used. If the length is 0, no request id will be generated.
|
||||
func WithIdLength(len int) Option {
|
||||
return func(o *options) {
|
||||
o.generateLength = len
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaderIdentifier specifies the header name for the request id that is added to the response headers.
|
||||
// If the identifier is empty, the request id will not be added to the response headers.
|
||||
func WithHeaderIdentifier(identifier string) Option {
|
||||
return func(o *options) {
|
||||
o.headerIdentifier = identifier
|
||||
}
|
||||
}
|
||||
|
||||
// WithUpstreamHeader sets the upstream header name, that should be used to fetch the request id.
|
||||
// If no upstream header is found, a random id will be generated if the id-length parameter is set to a value > 0.
|
||||
func WithUpstreamHeader(header string) Option {
|
||||
return func(o *options) {
|
||||
o.upstreamReqIdHeader = header
|
||||
}
|
||||
}
|
||||
|
||||
// WithContextIdentifier specifies the value-key for the request id that is added to the request context.
|
||||
// If the identifier is empty, the request id will not be added to the context.
|
||||
// If the request id is added to the context, it can be retrieved with:
|
||||
// `id := r.Context().Value(THE-IDENTIFIER).(string)`
|
||||
func WithContextIdentifier(identifier string) Option {
|
||||
return func(o *options) {
|
||||
o.contextIdentifier = identifier
|
||||
}
|
||||
}
|
||||
|
||||
// newOptions is a function that returns a new options struct with sane default values.
|
||||
func newOptions(opts ...Option) options {
|
||||
o := options{
|
||||
headerIdentifier: "X-Request-Id",
|
||||
contextIdentifier: "RequestId",
|
||||
generateSeed: time.Now().UnixNano(),
|
||||
generateCharset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
|
||||
generateLength: 8,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
75
internal/app/api/core/middleware/tracing/options_test.go
Normal file
75
internal/app/api/core/middleware/tracing/options_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithIdSeed(t *testing.T) {
|
||||
o := newOptions(WithIdSeed(12345))
|
||||
if o.generateSeed != 12345 {
|
||||
t.Errorf("expected generateSeed to be 12345, got %d", o.generateSeed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithIdCharset(t *testing.T) {
|
||||
o := newOptions(WithIdCharset("abc123"))
|
||||
if o.generateCharset != "abc123" {
|
||||
t.Errorf("expected generateCharset to be 'abc123', got %s", o.generateCharset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithIdLength(t *testing.T) {
|
||||
o := newOptions(WithIdLength(16))
|
||||
if o.generateLength != 16 {
|
||||
t.Errorf("expected generateLength to be 16, got %d", o.generateLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithHeaderIdentifier(t *testing.T) {
|
||||
o := newOptions(WithHeaderIdentifier("X-Custom-Id"))
|
||||
if o.headerIdentifier != "X-Custom-Id" {
|
||||
t.Errorf("expected headerIdentifier to be 'X-Custom-Id', got %s", o.headerIdentifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUpstreamHeader(t *testing.T) {
|
||||
o := newOptions(WithUpstreamHeader("X-Upstream-Id"))
|
||||
if o.upstreamReqIdHeader != "X-Upstream-Id" {
|
||||
t.Errorf("expected upstreamReqIdHeader to be 'X-Upstream-Id', got %s", o.upstreamReqIdHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithContextIdentifier(t *testing.T) {
|
||||
o := newOptions(WithContextIdentifier("Request-Id"))
|
||||
if o.contextIdentifier != "Request-Id" {
|
||||
t.Errorf("expected contextIdentifier to be 'Request-Id', got %s", o.contextIdentifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
o := newOptions()
|
||||
|
||||
if o.generateLength != 8 {
|
||||
t.Errorf("expected generateLength to be 8, got %d", o.generateLength)
|
||||
}
|
||||
|
||||
if o.generateCharset != "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" {
|
||||
t.Errorf("expected generateCharset to be 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', got %s", o.generateCharset)
|
||||
}
|
||||
|
||||
if o.generateSeed == 0 {
|
||||
t.Errorf("expected generateSeed to be non-zero")
|
||||
}
|
||||
|
||||
if o.headerIdentifier != "X-Request-Id" {
|
||||
t.Errorf("expected headerIdentifier to be 'X-Request-Id', got %s", o.headerIdentifier)
|
||||
}
|
||||
|
||||
if o.upstreamReqIdHeader != "" {
|
||||
t.Errorf("expected upstreamReqIdHeader to be empty, got %s", o.upstreamReqIdHeader)
|
||||
}
|
||||
|
||||
if o.contextIdentifier != "RequestId" {
|
||||
t.Errorf("expected contextIdentifier to be 'RequestId', got %s", o.contextIdentifier)
|
||||
}
|
||||
}
|
259
internal/app/api/core/request/basic.go
Normal file
259
internal/app/api/core/request/basic.go
Normal file
@ -0,0 +1,259 @@
|
||||
// Package request provides functions to extract parameters from the request.
|
||||
package request
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const CheckPrivateProxy = "PRIVATE"
|
||||
|
||||
// PathRaw returns the value of the named path parameter.
|
||||
func PathRaw(r *http.Request, name string) string {
|
||||
return r.PathValue(name)
|
||||
}
|
||||
|
||||
// Path returns the value of the named path parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Path(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(PathRaw(r, name))
|
||||
}
|
||||
|
||||
// PathDefault returns the value of the named path parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func PathDefault(r *http.Request, name string, defaultValue string) string {
|
||||
value := r.PathValue(name)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Path(r, name)
|
||||
}
|
||||
|
||||
// QueryRaw returns the value of the named query parameter.
|
||||
func QueryRaw(r *http.Request, name string) string {
|
||||
return r.URL.Query().Get(name)
|
||||
}
|
||||
|
||||
// Query returns the value of the named query parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Query(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(QueryRaw(r, name))
|
||||
}
|
||||
|
||||
// QueryDefault returns the value of the named query parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func QueryDefault(r *http.Request, name string, defaultValue string) string {
|
||||
if !r.URL.Query().Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Query(r, name)
|
||||
}
|
||||
|
||||
// QuerySlice returns the value of the named query parameter.
|
||||
// All slice values are trimmed of leading and trailing whitespace.
|
||||
func QuerySlice(r *http.Request, name string) []string {
|
||||
values, ok := r.URL.Query()[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, len(values))
|
||||
for i, value := range values {
|
||||
result[i] = strings.TrimSpace(value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// QuerySliceDefault returns the value of the named query parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// All slice values are trimmed of leading and trailing whitespace.
|
||||
func QuerySliceDefault(r *http.Request, name string, defaultValue []string) []string {
|
||||
if !r.URL.Query().Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return QuerySlice(r, name)
|
||||
}
|
||||
|
||||
// FragmentRaw returns the value of the named fragment parameter.
|
||||
func FragmentRaw(r *http.Request) string {
|
||||
return r.URL.Fragment
|
||||
}
|
||||
|
||||
// Fragment returns the value of the named fragment parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Fragment(r *http.Request) string {
|
||||
return strings.TrimSpace(FragmentRaw(r))
|
||||
}
|
||||
|
||||
// FragmentDefault returns the value of the named fragment parameter.
|
||||
// If the parameter is empty, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func FragmentDefault(r *http.Request, defaultValue string) string {
|
||||
if r.URL.Fragment == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Fragment(r)
|
||||
}
|
||||
|
||||
// FormRaw returns the value of the named form parameter.
|
||||
func FormRaw(r *http.Request, name string) string {
|
||||
return r.FormValue(name)
|
||||
}
|
||||
|
||||
// Form returns the value of the named form parameter.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Form(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(FormRaw(r, name))
|
||||
}
|
||||
|
||||
// DefaultForm returns the value of the named form parameter.
|
||||
// If the parameter is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func DefaultForm(r *http.Request, name, defaultValue string) string {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if !r.Form.Has(name) {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Form(r, name)
|
||||
}
|
||||
|
||||
// HeaderRaw returns the value of the named header.
|
||||
func HeaderRaw(r *http.Request, name string) string {
|
||||
return r.Header.Get(name)
|
||||
}
|
||||
|
||||
// Header returns the value of the named header.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Header(r *http.Request, name string) string {
|
||||
return strings.TrimSpace(HeaderRaw(r, name))
|
||||
}
|
||||
|
||||
// HeaderDefault returns the value of the named header.
|
||||
// If the header is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func HeaderDefault(r *http.Request, name, defaultValue string) string {
|
||||
if _, ok := textproto.MIMEHeader(r.Header)[name]; !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return Header(r, name)
|
||||
}
|
||||
|
||||
// Cookie returns the value of the named cookie.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func Cookie(r *http.Request, name string) string {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cookie.Value)
|
||||
}
|
||||
|
||||
// CookieDefault returns the value of the named cookie.
|
||||
// If the cookie is not set, it returns the default value.
|
||||
// The return value is trimmed of leading and trailing whitespace.
|
||||
func CookieDefault(r *http.Request, name, defaultValue string) string {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cookie.Value)
|
||||
}
|
||||
|
||||
// ClientIp returns the client IP address.
|
||||
//
|
||||
// As the request may come from a proxy, the function checks the
|
||||
// X-Real-Ip and X-Forwarded-For headers to get the real client IP
|
||||
// if the request IP matches one of the allowed proxy IPs.
|
||||
// If the special proxy value CheckPrivateProxy ("PRIVATE") is passed, the function will
|
||||
// also check the header if the request IP is a private IP address.
|
||||
func ClientIp(r *http.Request, allowedProxyIp ...string) string {
|
||||
ipStr, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
switch {
|
||||
case err != nil && strings.Contains(err.Error(), "missing port in address"):
|
||||
ipStr = strings.TrimSpace(r.RemoteAddr)
|
||||
case err != nil:
|
||||
ipStr = ""
|
||||
}
|
||||
IP := net.ParseIP(ipStr)
|
||||
if IP == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
isProxiedRequest := false
|
||||
if len(allowedProxyIp) > 0 {
|
||||
if slices.Contains(allowedProxyIp, IP.String()) {
|
||||
isProxiedRequest = true
|
||||
}
|
||||
if IP.IsPrivate() && slices.Contains(allowedProxyIp, CheckPrivateProxy) {
|
||||
isProxiedRequest = true
|
||||
}
|
||||
}
|
||||
|
||||
if isProxiedRequest {
|
||||
realClientIP := r.Header.Get("X-Real-Ip")
|
||||
if realClientIP == "" {
|
||||
realClientIP = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
if realClientIP != "" {
|
||||
realIpStr, _, err := net.SplitHostPort(strings.TrimSpace(realClientIP))
|
||||
switch {
|
||||
case err != nil && strings.Contains(err.Error(), "missing port in address"):
|
||||
realIpStr = realClientIP
|
||||
case err != nil:
|
||||
realIpStr = ipStr
|
||||
}
|
||||
realIP := net.ParseIP(realIpStr)
|
||||
if realIP == nil {
|
||||
return IP.String()
|
||||
}
|
||||
return realIP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return IP.String()
|
||||
}
|
||||
|
||||
// BodyJson decodes the JSON value from the request body into the target.
|
||||
// The target must be a pointer to a struct or slice.
|
||||
// The function returns an error if the JSON value could not be decoded.
|
||||
// The body reader is closed after reading.
|
||||
func BodyJson(r *http.Request, target any) error {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
return json.NewDecoder(r.Body).Decode(target)
|
||||
}
|
||||
|
||||
// BodyString returns the request body as a string.
|
||||
// The content is read and returned as is, without any processing.
|
||||
// The body is assumed to be UTF-8 encoded.
|
||||
func BodyString(r *http.Request) (string, error) {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bodyBytes), nil
|
||||
}
|
221
internal/app/api/core/request/basic_test.go
Normal file
221
internal/app/api/core/request/basic_test.go
Normal file
@ -0,0 +1,221 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPath(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/test/sample"}}
|
||||
r.SetPathValue("first", "test")
|
||||
if got := Path(r, "first"); got != "test" {
|
||||
t.Errorf("Path() = %v, want %v", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPath(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/"}}
|
||||
if got := PathDefault(r, "test", "default"); got != "default" {
|
||||
t.Errorf("PathDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPath_noDefault(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Path: "/"}}
|
||||
r.SetPathValue("first", "test")
|
||||
if got := PathDefault(r, "first", "test"); got != "test" {
|
||||
t.Errorf("PathDefault() = %v, want %v", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value"}}
|
||||
if got := Query(r, "name"); got != "value" {
|
||||
t.Errorf("Query() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuery(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: ""}}
|
||||
if got := QueryDefault(r, "name", "default"); got != "default" {
|
||||
t.Errorf("QueryDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySlice(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value1 &name=value2"}}
|
||||
expected := []string{"value1", "value2"}
|
||||
if got := QuerySlice(r, "name"); !slices.Equal(got, expected) {
|
||||
t.Errorf("QuerySlice() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySlice_empty(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: "name=value1&name=value2"}}
|
||||
if got := QuerySlice(r, "nix"); !slices.Equal(got, nil) {
|
||||
t.Errorf("QuerySlice() = %v, want %v", got, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuerySlice(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{RawQuery: ""}}
|
||||
defaultValue := []string{"default1", "default2"}
|
||||
if got := QuerySliceDefault(r, "name", defaultValue); !slices.Equal(got, defaultValue) {
|
||||
t.Errorf("QuerySliceDefault() = %v, want %v", got, defaultValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragment(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Fragment: "section"}}
|
||||
if got := Fragment(r); got != "section" {
|
||||
t.Errorf("Fragment() = %v, want %v", got, "section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultFragment(t *testing.T) {
|
||||
r := &http.Request{URL: &url.URL{Fragment: ""}}
|
||||
if got := FragmentDefault(r, "default"); got != "default" {
|
||||
t.Errorf("FragmentDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForm(t *testing.T) {
|
||||
r := &http.Request{Form: url.Values{"name": {"value"}}}
|
||||
if got := Form(r, "name"); got != "value" {
|
||||
t.Errorf("Form() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultForm(t *testing.T) {
|
||||
r := &http.Request{Form: url.Values{}}
|
||||
if got := DefaultForm(r, "name", "default"); got != "default" {
|
||||
t.Errorf("DefaultForm() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeader(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{"X-Test-Header": {"value"}}}
|
||||
if got := Header(r, "X-Test-Header"); got != "value" {
|
||||
t.Errorf("Header() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultHeader(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{}}
|
||||
if got := HeaderDefault(r, "X-Test-Header", "default"); got != "default" {
|
||||
t.Errorf("HeaderDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookie(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{"Cookie": {"name=value"}}}
|
||||
if got := Cookie(r, "name"); got != "value" {
|
||||
t.Errorf("Cookie() = %v, want %v", got, "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCookie(t *testing.T) {
|
||||
r := &http.Request{Header: http.Header{}}
|
||||
if got := CookieDefault(r, "name", "default"); got != "default" {
|
||||
t.Errorf("CookieDefault() = %v, want %v", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345"}
|
||||
if got := ClientIp(r); got != "192.168.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_invalid(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "was_isn_des"}
|
||||
if got := ClientIp(r); got != "" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_ignore_header(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r); got != "192.168.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header1(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header2(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header3(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "123.45.67.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header4(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "8.8.8.8:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "8.8.8.8" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "8.8.8.8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIp_header_invalid(t *testing.T) {
|
||||
r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"so-sicher-nit"}}}
|
||||
if got := ClientIp(r, "1.1.1.1"); got != "1.1.1.1" {
|
||||
t.Errorf("ClientIp() = %v, want %v", got, "1.1.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyJson(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
jsonStr := `{"name": "test", "value": 123}`
|
||||
r := &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader(jsonStr)),
|
||||
}
|
||||
|
||||
var result TestStruct
|
||||
err := BodyJson(r, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("BodyJson() error = %v", err)
|
||||
}
|
||||
|
||||
expected := TestStruct{Name: "test", Value: 123}
|
||||
if result != expected {
|
||||
t.Errorf("BodyJson() = %v, want %v", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyString(t *testing.T) {
|
||||
bodyStr := "test body content"
|
||||
r := &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader(bodyStr)),
|
||||
}
|
||||
|
||||
result, err := BodyString(r)
|
||||
if err != nil {
|
||||
t.Fatalf("BodyString() error = %v", err)
|
||||
}
|
||||
|
||||
if result != bodyStr {
|
||||
t.Errorf("BodyString() = %v, want %v", result, bodyStr)
|
||||
}
|
||||
}
|
100
internal/app/api/core/respond/basic.go
Normal file
100
internal/app/api/core/respond/basic.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Package respond provides a set of utility functions to help with the HTTP response handling.
|
||||
package respond
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Status writes a response with the given status code.
|
||||
// The response will not contain any data.
|
||||
func Status(w http.ResponseWriter, code int) {
|
||||
w.WriteHeader(code)
|
||||
}
|
||||
|
||||
// String writes a plain text response with the given status code and data.
|
||||
// The Content-Type header is set to text/plain with a charset of utf-8.
|
||||
func String(w http.ResponseWriter, code int, data string) {
|
||||
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = w.Write([]byte(data))
|
||||
}
|
||||
|
||||
// JSON writes a JSON response with the given status code and data.
|
||||
// If data is nil, the response will null. The status code is set to the given code.
|
||||
// The Content-Type header is set to application/json.
|
||||
// If the given data is not JSON serializable, the response will not contain any data.
|
||||
// All encoding errors are silently ignored.
|
||||
func JSON(w http.ResponseWriter, code int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// if no data was given, simply return null
|
||||
if data == nil {
|
||||
w.WriteHeader(code)
|
||||
_, _ = w.Write([]byte("null"))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
_ = json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// Data writes a response with the given status code, content type, and data.
|
||||
// If no content type is provided, it is detected from the data.
|
||||
func Data(w http.ResponseWriter, code int, contentType string, data []byte) {
|
||||
if contentType == "" {
|
||||
contentType = http.DetectContentType(data) // ensure content type is set
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// Reader writes a response with the given status code, content type, and data.
|
||||
// The content length is optional, it is only set if the given length is greater than 0.
|
||||
func Reader(w http.ResponseWriter, code int, contentType string, contentLength int, data io.Reader) {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if contentLength > 0 {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(contentLength))
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, _ = io.Copy(w, data)
|
||||
}
|
||||
|
||||
// Attachment writes a response with the given status code, content type, filename, and data.
|
||||
// If no content type is provided, it is detected from the data.
|
||||
func Attachment(w http.ResponseWriter, code int, filename, contentType string, data []byte) {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
|
||||
|
||||
Data(w, code, contentType, data)
|
||||
}
|
||||
|
||||
// AttachmentReader writes a response with the given status code, content type, filename, content length, and data.
|
||||
// The content length is optional, it is only set if the given length is greater than 0.
|
||||
func AttachmentReader(
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
filename, contentType string,
|
||||
contentLength int,
|
||||
data io.Reader,
|
||||
) {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+filename)
|
||||
|
||||
Reader(w, code, contentType, contentLength, data)
|
||||
}
|
||||
|
||||
// Redirect writes a response with the given status code and redirects to the given URL.
|
||||
// The redirect url will always be an absolute URL. If the given URL is relative,
|
||||
// the original request URL is used as the base.
|
||||
func Redirect(w http.ResponseWriter, r *http.Request, code int, url string) {
|
||||
http.Redirect(w, r, url, code)
|
||||
}
|
273
internal/app/api/core/respond/basic_test.go
Normal file
273
internal/app/api/core/respond/basic_test.go
Normal file
@ -0,0 +1,273 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
Status(rec, http.StatusNoContent)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNoContent, res.StatusCode)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if len(body) != 0 {
|
||||
t.Errorf("expected no body, got %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
String(rec, http.StatusOK, "Hello, World!")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain;charset=utf-8" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain;charset=utf-8", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := map[string]string{"hello": "world"}
|
||||
JSON(rec, http.StatusOK, data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("expected content type %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
_ = json.NewDecoder(res.Body).Decode(&body)
|
||||
if body["hello"] != "world" {
|
||||
t.Errorf("expected body %v, got %v", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON_empty(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
JSON(rec, http.StatusOK, nil)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("expected content type %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "null" {
|
||||
t.Errorf("expected body %s, got %s", "null", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
Data(rec, http.StatusOK, "text/plain", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestData_noContentType(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte{0x1, 0x2, 0x3, 0x4, 0x5}
|
||||
Data(rec, http.StatusOK, "", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "application/octet-stream" {
|
||||
t.Errorf("expected content type %s, got %s", "application/octet-stream", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
reader := bytes.NewBufferString(string(data))
|
||||
Reader(rec, http.StatusOK, "text/plain", len(data), reader)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
if contentLength := res.Header.Get("Content-Length"); contentLength != strconv.Itoa(len(data)) {
|
||||
t.Errorf("expected content length %d, got %s", len(data), contentLength)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_unknownLength(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := bytes.NewBufferString("Hello, World!")
|
||||
Reader(rec, http.StatusOK, "text/plain", 0, data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
if contentLength := res.Header.Get("Content-Length"); contentLength != "" {
|
||||
t.Errorf("expected no content length, got %s", contentLength)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachment(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := []byte("Hello, World!")
|
||||
Attachment(rec, http.StatusOK, "example.txt", "text/plain", data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
|
||||
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if !bytes.Equal(body, data) {
|
||||
t.Errorf("expected body %s, got %s", data, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachmentReader(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
data := bytes.NewBufferString("Hello, World!")
|
||||
AttachmentReader(rec, http.StatusOK, "example.txt", "text/plain", data.Len(), data)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" {
|
||||
t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if string(body) != "Hello, World!" {
|
||||
t.Errorf("expected body %s, got %s", "Hello, World!", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/old", nil)
|
||||
url := "http://example.com/new"
|
||||
|
||||
Redirect(rec, req, http.StatusMovedPermanently, url)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusMovedPermanently {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
|
||||
}
|
||||
|
||||
if location := res.Header.Get("Location"); location != url {
|
||||
t.Errorf("expected location %s, got %s", url, location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect_relative(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/old/dir", nil)
|
||||
url := "newlocation/sub"
|
||||
want := "/old/newlocation/sub"
|
||||
|
||||
Redirect(rec, req, http.StatusMovedPermanently, url)
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusMovedPermanently {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode)
|
||||
}
|
||||
|
||||
if location := res.Header.Get("Location"); location != want {
|
||||
t.Errorf("expected location %s, got %s", want, location)
|
||||
}
|
||||
}
|
46
internal/app/api/core/respond/template.go
Normal file
46
internal/app/api/core/respond/template.go
Normal file
@ -0,0 +1,46 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TplData is a map of template data. This is a convenience type for passing data to templates.
|
||||
type TplData map[string]any
|
||||
|
||||
// TemplateInstance is an interface that wraps the ExecuteTemplate method.
|
||||
// It is implemented by the html/template and text/template packages.
|
||||
type TemplateInstance interface {
|
||||
// ExecuteTemplate executes a template with the given name and data.
|
||||
ExecuteTemplate(wr io.Writer, name string, data any) error
|
||||
}
|
||||
|
||||
// TemplateRenderer is a renderer that uses a template instance to render HTML or Text templates.
|
||||
type TemplateRenderer struct {
|
||||
t TemplateInstance
|
||||
}
|
||||
|
||||
// NewTemplateRenderer creates a new HTML or Text template renderer with the given template instance.
|
||||
func NewTemplateRenderer(t TemplateInstance) *TemplateRenderer {
|
||||
return &TemplateRenderer{t: t}
|
||||
}
|
||||
|
||||
// Render renders a template with the given name and data.
|
||||
// If rendering fails, it will panic with an error.
|
||||
func (r *TemplateRenderer) Render(w http.ResponseWriter, code int, name, contentType string, data any) {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.WriteHeader(code)
|
||||
|
||||
err := r.t.ExecuteTemplate(w, name, data)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("error rendering template %s: %v", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
// HTML renders a template with the given name and data. It is a convenience method for Render.
|
||||
// The content type is set to "text/html" and the encoding to "utf-8".
|
||||
// If rendering fails, it will panic with an error.
|
||||
func (r *TemplateRenderer) HTML(w http.ResponseWriter, code int, name string, data any) {
|
||||
r.Render(w, code, name, "text/html;charset=utf-8", data)
|
||||
}
|
67
internal/app/api/core/respond/template_test.go
Normal file
67
internal/app/api/core/respond/template_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package respond
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockTemplate struct {
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
func (m *mockTemplate) ExecuteTemplate(wr io.Writer, name string, data any) error {
|
||||
return m.tmpl.ExecuteTemplate(wr, name, data)
|
||||
}
|
||||
|
||||
func TestTemplateRenderer_Render(t *testing.T) {
|
||||
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}Hello, {{.}}!{{end}}`))
|
||||
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
renderer.Render(rec, http.StatusOK, "test", "text/plain", "World")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" {
|
||||
t.Errorf("expected content type %s, got %s", "text/plain", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
expectedBody := "Hello, World!"
|
||||
if string(body) != expectedBody {
|
||||
t.Errorf("expected body %s, got %s", expectedBody, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateRenderer_HTML(t *testing.T) {
|
||||
tmpl := template.Must(template.New("test").Parse(`{{define "test"}}<p>Hello, {{.}}!</p>{{end}}`))
|
||||
renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
renderer.HTML(rec, http.StatusOK, "test", "World")
|
||||
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
if contentType := res.Header.Get("Content-Type"); contentType != "text/html;charset=utf-8" {
|
||||
t.Errorf("expected content type %s, got %s", "text/html;charset=utf-8", contentType)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
expectedBody := "<p>Hello, World!</p>"
|
||||
if string(body) != expectedBody {
|
||||
t.Errorf("expected body %s, got %s", expectedBody, string(body))
|
||||
}
|
||||
}
|
@ -2,27 +2,25 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/logging"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/recovery"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/tracing"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
random = rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
|
||||
)
|
||||
|
||||
const (
|
||||
RequestIDKey = "X-Request-ID"
|
||||
)
|
||||
@ -30,19 +28,21 @@ const (
|
||||
type ApiVersion string
|
||||
type HandlerName string
|
||||
|
||||
type GroupSetupFn func(group *gin.RouterGroup)
|
||||
type GroupSetupFn func(group *routegroup.Bundle)
|
||||
|
||||
type ApiEndpointSetupFunc func() (ApiVersion, GroupSetupFn)
|
||||
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
server *gin.Engine
|
||||
versions map[ApiVersion]*gin.RouterGroup
|
||||
server *routegroup.Bundle
|
||||
tpl *respond.TemplateRenderer
|
||||
versions map[ApiVersion]*routegroup.Bundle
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, error) {
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
cfg: cfg,
|
||||
server: routegroup.New(http.NewServeMux()),
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
@ -51,69 +51,39 @@ func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server,
|
||||
}
|
||||
hostname += ", version " + internal.Version
|
||||
|
||||
// Setup http server
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
s.server = gin.New()
|
||||
|
||||
s.server.Use(recovery.New().Handler)
|
||||
if cfg.Web.RequestLogging {
|
||||
if cfg.Advanced.LogLevel == "trace" {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
s.server.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
s.server.Use(logging.New(logging.WithLevel(logging.LogLevelDebug)).Handler)
|
||||
|
||||
c.Next()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
errorMsg := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
|
||||
slog.Debug("HTTP Request",
|
||||
"status", status,
|
||||
"latency", latency,
|
||||
"client", clientIP,
|
||||
"method", method,
|
||||
"path", path,
|
||||
"error", errorMsg,
|
||||
)
|
||||
}
|
||||
s.server.Use(cors.New().Handler)
|
||||
s.server.Use(tracing.New(
|
||||
tracing.WithContextIdentifier(RequestIDKey),
|
||||
tracing.WithHeaderIdentifier(RequestIDKey),
|
||||
).Handler)
|
||||
if cfg.Web.ExposeHostInfo {
|
||||
s.server.Use(func(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Served-By", hostname)
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
s.server.Use(gin.Recovery()).Use(func(c *gin.Context) {
|
||||
c.Writer.Header().Set("X-Served-By", hostname)
|
||||
c.Next()
|
||||
}).Use(func(c *gin.Context) {
|
||||
xRequestID := uuid(16)
|
||||
|
||||
c.Request.Header.Set(RequestIDKey, xRequestID)
|
||||
c.Set(RequestIDKey, xRequestID)
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// Setup templates
|
||||
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(apiTemplates, "assets/tpl/*.gohtml"))
|
||||
s.server.SetHTMLTemplate(templates)
|
||||
s.tpl = respond.NewTemplateRenderer(
|
||||
template.Must(template.New("").ParseFS(apiTemplates, "assets/tpl/*.gohtml")),
|
||||
)
|
||||
|
||||
// Serve static files
|
||||
imgFs := http.FS(fsMust(fs.Sub(apiStatics, "assets/img")))
|
||||
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
|
||||
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
|
||||
s.server.StaticFS("/img", imgFs)
|
||||
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
|
||||
s.server.StaticFS("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
|
||||
s.server.HandleFiles("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css"))))
|
||||
s.server.HandleFiles("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js"))))
|
||||
s.server.HandleFiles("/img", imgFs)
|
||||
s.server.HandleFiles("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts"))))
|
||||
s.server.HandleFiles("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc"))))
|
||||
|
||||
// Setup routes
|
||||
s.server.UseRawPath = true
|
||||
s.server.UnescapePathValues = true
|
||||
s.setupRoutes(endpoints...)
|
||||
s.setupFrontendRoutes()
|
||||
|
||||
@ -136,9 +106,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Info("web service exited",
|
||||
"address", listenAddress,
|
||||
"error", err)
|
||||
slog.Info("web service exited", "address", listenAddress, "error", err)
|
||||
cancelFn()
|
||||
}
|
||||
}()
|
||||
@ -157,18 +125,18 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
|
||||
s.server.GET("/api", s.landingPage)
|
||||
s.versions = make(map[ApiVersion]*gin.RouterGroup)
|
||||
s.server.HandleFunc("GET /api", s.landingPage)
|
||||
s.versions = make(map[ApiVersion]*routegroup.Bundle)
|
||||
|
||||
for _, setupFunc := range endpoints {
|
||||
version, groupSetupFn := setupFunc()
|
||||
|
||||
if _, ok := s.versions[version]; !ok {
|
||||
s.versions[version] = s.server.Group(fmt.Sprintf("/api/%s", version))
|
||||
s.versions[version] = s.server.Mount(fmt.Sprintf("/api/%s", version))
|
||||
|
||||
// OpenAPI documentation (via RapiDoc)
|
||||
s.versions[version].GET("/swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
|
||||
s.versions[version].GET("/doc.html", s.rapiDocHandler(version))
|
||||
s.versions[version].HandleFunc("GET /swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link
|
||||
s.versions[version].HandleFunc("GET /doc.html", s.rapiDocHandler(version))
|
||||
|
||||
groupSetupFn(s.versions[version])
|
||||
}
|
||||
@ -177,25 +145,27 @@ func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) {
|
||||
|
||||
func (s *Server) setupFrontendRoutes() {
|
||||
// Serve static files
|
||||
s.server.GET("/", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/app")
|
||||
s.server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.Redirect(w, r, http.StatusMovedPermanently, "/app")
|
||||
})
|
||||
s.server.GET("/favicon.ico", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/app/favicon.ico")
|
||||
|
||||
s.server.HandleFunc("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.Redirect(w, r, http.StatusMovedPermanently, "/app/favicon.ico")
|
||||
})
|
||||
s.server.StaticFS("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
|
||||
|
||||
s.server.HandleFiles("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist"))))
|
||||
}
|
||||
|
||||
func (s *Server) landingPage(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "index.gohtml", gin.H{
|
||||
func (s *Server) landingPage(w http.ResponseWriter, _ *http.Request) {
|
||||
s.tpl.HTML(w, http.StatusOK, "index.gohtml", respond.TplData{
|
||||
"Version": internal.Version,
|
||||
"Year": time.Now().Year(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) rapiDocHandler(version ApiVersion) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "rapidoc.gohtml", gin.H{
|
||||
func (s *Server) rapiDocHandler(version ApiVersion) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
s.tpl.HTML(w, http.StatusOK, "rapidoc.gohtml", respond.TplData{
|
||||
"RapiDocSource": "/js/rapidoc-min.js",
|
||||
"ApiSpecUrl": fmt.Sprintf("/doc/%s_swagger.yaml", version),
|
||||
"Version": internal.Version,
|
||||
@ -210,9 +180,3 @@ func fsMust(f fs.FS, err error) fs.FS {
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func uuid(len int) string {
|
||||
bytes := make([]byte, len)
|
||||
random.Read(bytes)
|
||||
return base64.StdEncoding.EncodeToString(bytes)[:len]
|
||||
}
|
||||
|
@ -1,24 +1,46 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
"github.com/gin-gonic/gin"
|
||||
csrf "github.com/utrack/gin-csrf"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/csrf"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
)
|
||||
|
||||
type handler interface {
|
||||
type SessionMiddleware interface {
|
||||
// SetData sets the session data for the given context.
|
||||
SetData(ctx context.Context, val SessionData)
|
||||
// GetData returns the session data for the given context. If no data is found, the default session data is returned.
|
||||
GetData(ctx context.Context) SessionData
|
||||
// DestroyData destroys the session data for the given context.
|
||||
DestroyData(ctx context.Context)
|
||||
|
||||
// GetString returns the string value for the given key. If no value is found, an empty string is returned.
|
||||
GetString(ctx context.Context, key string) string
|
||||
// Put sets the value for the given key.
|
||||
Put(ctx context.Context, key string, value any)
|
||||
// LoadAndSave is a middleware that loads the session data for the given request and saves it after the request is
|
||||
// finished.
|
||||
LoadAndSave(next http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
// GetName returns the name of the handler.
|
||||
GetName() string
|
||||
RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler)
|
||||
// RegisterRoutes registers the routes for the handler. The session manager is passed to the handler.
|
||||
RegisterRoutes(g *routegroup.Bundle)
|
||||
}
|
||||
|
||||
type Authenticator interface {
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler
|
||||
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
|
||||
UserIdMatch(idParameter string) func(next http.Handler) http.Handler
|
||||
}
|
||||
|
||||
// To compile the API documentation use the
|
||||
@ -35,54 +57,33 @@ type handler interface {
|
||||
// @BasePath /api/v0
|
||||
// @query.collection.format multi
|
||||
|
||||
func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc {
|
||||
authenticator := &authenticationHandler{
|
||||
app: app,
|
||||
Session: GinSessionStore{sessionIdentifier: cfg.Web.SessionIdentifier},
|
||||
}
|
||||
|
||||
handlers := make([]handler, 0, 1)
|
||||
handlers = append(handlers, testEndpoint{})
|
||||
handlers = append(handlers, userEndpoint{app: app, authenticator: authenticator})
|
||||
handlers = append(handlers, newConfigEndpoint(app, authenticator))
|
||||
handlers = append(handlers, authEndpoint{app: app, authenticator: authenticator})
|
||||
handlers = append(handlers, interfaceEndpoint{app: app, authenticator: authenticator})
|
||||
handlers = append(handlers, peerEndpoint{app: app, authenticator: authenticator})
|
||||
|
||||
func NewRestApi(
|
||||
session SessionMiddleware,
|
||||
handlers ...Handler,
|
||||
) core.ApiEndpointSetupFunc {
|
||||
return func() (core.ApiVersion, core.GroupSetupFn) {
|
||||
return "v0", func(group *gin.RouterGroup) {
|
||||
cookieStore := memstore.NewStore([]byte(cfg.Web.SessionSecret))
|
||||
cookieStore.Options(sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400, // auth session is valid for 1 day
|
||||
Secure: strings.HasPrefix(cfg.Web.ExternalUrl, "https"),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
return "v0", func(group *routegroup.Bundle) {
|
||||
csrfMiddleware := csrf.New(func(r *http.Request) string {
|
||||
return session.GetString(r.Context(), "csrf_token")
|
||||
}, func(r *http.Request, token string) {
|
||||
session.Put(r.Context(), "csrf_token", token)
|
||||
})
|
||||
group.Use(sessions.Sessions(cfg.Web.SessionIdentifier, cookieStore))
|
||||
group.Use(cors.Default())
|
||||
group.Use(csrf.Middleware(csrf.Options{
|
||||
Secret: cfg.Web.CsrfSecret,
|
||||
ErrorFunc: func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "CSRF token mismatch",
|
||||
})
|
||||
c.Abort()
|
||||
},
|
||||
}))
|
||||
|
||||
group.GET("/csrf", handleCsrfGet())
|
||||
group.Use(session.LoadAndSave)
|
||||
group.Use(csrfMiddleware.Handler)
|
||||
group.Use(cors.New().Handler)
|
||||
|
||||
group.With(csrfMiddleware.RefreshToken).HandleFunc("GET /csrf", handleCsrfGet())
|
||||
|
||||
// Handler functions
|
||||
for _, h := range handlers {
|
||||
h.RegisterRoutes(group, authenticator)
|
||||
h.RegisterRoutes(group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCsrfGet returns a gorm handler function.
|
||||
// handleCsrfGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID base_handleCsrfGet
|
||||
// @Tags Security
|
||||
@ -90,8 +91,12 @@ func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc {
|
||||
// @Produce json
|
||||
// @Success 200 {object} string
|
||||
// @Router /csrf [get]
|
||||
func handleCsrfGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, csrf.GetToken(c))
|
||||
func handleCsrfGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.JSON(w, http.StatusOK, csrf.GetToken(r.Context()))
|
||||
}
|
||||
}
|
||||
|
||||
// region session wrapper
|
||||
|
||||
// endregion session wrapper
|
||||
|
@ -8,36 +8,62 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type authEndpoint struct {
|
||||
app *app.App
|
||||
authenticator *authenticationHandler
|
||||
type Session interface {
|
||||
// SetData sets the session data for the given context.
|
||||
SetData(ctx context.Context, val SessionData)
|
||||
// GetData returns the session data for the given context. If no data is found, the default session data is returned.
|
||||
GetData(ctx context.Context) SessionData
|
||||
// DestroyData destroys the session data for the given context.
|
||||
DestroyData(ctx context.Context)
|
||||
}
|
||||
|
||||
func (e authEndpoint) GetName() string {
|
||||
type Validator interface {
|
||||
Struct(s interface{}) error
|
||||
}
|
||||
|
||||
type AuthEndpoint struct {
|
||||
app *app.App
|
||||
authenticator Authenticator
|
||||
session Session
|
||||
validate Validator
|
||||
}
|
||||
|
||||
func NewAuthEndpoint(app *app.App, authenticator Authenticator, session Session, validator Validator) AuthEndpoint {
|
||||
return AuthEndpoint{
|
||||
app: app,
|
||||
authenticator: authenticator,
|
||||
session: session,
|
||||
validate: validator,
|
||||
}
|
||||
}
|
||||
|
||||
func (e AuthEndpoint) GetName() string {
|
||||
return "AuthEndpoint"
|
||||
}
|
||||
|
||||
func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/auth")
|
||||
func (e AuthEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/auth")
|
||||
|
||||
apiGroup.GET("/providers", e.handleExternalLoginProvidersGet())
|
||||
apiGroup.GET("/session", e.handleSessionInfoGet())
|
||||
apiGroup.HandleFunc("GET /providers", e.handleExternalLoginProvidersGet())
|
||||
apiGroup.HandleFunc("GET /session", e.handleSessionInfoGet())
|
||||
|
||||
apiGroup.GET("/login/:provider/init", e.handleOauthInitiateGet())
|
||||
apiGroup.GET("/login/:provider/callback", e.handleOauthCallbackGet())
|
||||
apiGroup.HandleFunc("GET /login/{provider}/init", e.handleOauthInitiateGet())
|
||||
apiGroup.HandleFunc("GET /login/{provider}/callback", e.handleOauthCallbackGet())
|
||||
|
||||
apiGroup.POST("/login", e.handleLoginPost())
|
||||
apiGroup.POST("/logout", authenticator.LoggedIn(), e.handleLogoutPost())
|
||||
apiGroup.HandleFunc("POST /login", e.handleLoginPost())
|
||||
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /logout", e.handleLogoutPost())
|
||||
}
|
||||
|
||||
// handleExternalLoginProvidersGet returns a gorm handler function.
|
||||
// handleExternalLoginProvidersGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleExternalLoginProvidersGet
|
||||
// @Tags Authentication
|
||||
@ -45,16 +71,15 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
|
||||
// @Produce json
|
||||
// @Success 200 {object} []model.LoginProviderInfo
|
||||
// @Router /auth/providers [get]
|
||||
func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
providers := e.app.Authenticator.GetExternalLoginProviders(ctx)
|
||||
func (e AuthEndpoint) handleExternalLoginProvidersGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
providers := e.app.Authenticator.GetExternalLoginProviders(r.Context())
|
||||
|
||||
c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers))
|
||||
respond.JSON(w, http.StatusOK, model.NewLoginProviderInfos(providers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionInfoGet returns a gorm handler function.
|
||||
// handleSessionInfoGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleSessionInfoGet
|
||||
// @Tags Authentication
|
||||
@ -63,9 +88,9 @@ func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
|
||||
// @Success 200 {object} []model.SessionInfo
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /auth/session [get]
|
||||
func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) handleSessionInfoGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
|
||||
var loggedInUid *string
|
||||
var firstname *string
|
||||
@ -83,7 +108,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
|
||||
email = &e
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SessionInfo{
|
||||
respond.JSON(w, http.StatusOK, model.SessionInfo{
|
||||
LoggedIn: currentSession.LoggedIn,
|
||||
IsAdmin: currentSession.IsAdmin,
|
||||
UserIdentifier: loggedInUid,
|
||||
@ -94,7 +119,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// handleOauthInitiateGet returns a gorm handler function.
|
||||
// handleOauthInitiateGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleOauthInitiateGet
|
||||
// @Tags Authentication
|
||||
@ -102,23 +127,24 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
|
||||
// @Produce json
|
||||
// @Success 200 {object} []model.LoginProviderInfo
|
||||
// @Router /auth/{provider}/init [get]
|
||||
func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
|
||||
autoRedirect, _ := strconv.ParseBool(c.DefaultQuery("redirect", "false"))
|
||||
returnTo := c.Query("return")
|
||||
provider := c.Param("provider")
|
||||
autoRedirect, _ := strconv.ParseBool(request.QueryDefault(r, "redirect", "false"))
|
||||
returnTo := request.Query(r, "return")
|
||||
provider := request.Path(r, "provider")
|
||||
|
||||
var returnUrl *url.URL
|
||||
var returnParams string
|
||||
redirectToReturn := func() {
|
||||
c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams)
|
||||
respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams)
|
||||
}
|
||||
|
||||
if returnTo != "" {
|
||||
if !e.isValidReturnUrl(returnTo) {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"})
|
||||
return
|
||||
}
|
||||
if u, err := url.Parse(returnTo); err == nil {
|
||||
@ -137,34 +163,34 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
|
||||
returnParams = queryParams.Encode()
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "already logged in"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "already logged in"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider)
|
||||
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(context.Background(), provider)
|
||||
if err != nil {
|
||||
if autoRedirect && e.isValidReturnUrl(returnTo) {
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
authSession := e.authenticator.Session.DefaultSessionData()
|
||||
authSession := e.session.GetData(r.Context())
|
||||
authSession.OauthState = state
|
||||
authSession.OauthNonce = nonce
|
||||
authSession.OauthProvider = provider
|
||||
authSession.OauthReturnTo = returnTo
|
||||
e.authenticator.Session.SetData(c, authSession)
|
||||
e.session.SetData(r.Context(), authSession)
|
||||
|
||||
if autoRedirect {
|
||||
c.Redirect(http.StatusFound, authCodeUrl)
|
||||
respond.Redirect(w, r, http.StatusFound, authCodeUrl)
|
||||
} else {
|
||||
c.JSON(http.StatusOK, model.OauthInitiationResponse{
|
||||
respond.JSON(w, http.StatusOK, model.OauthInitiationResponse{
|
||||
RedirectUrl: authCodeUrl,
|
||||
State: state,
|
||||
})
|
||||
@ -172,7 +198,7 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// handleOauthCallbackGet returns a gorm handler function.
|
||||
// handleOauthCallbackGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleOauthCallbackGet
|
||||
// @Tags Authentication
|
||||
@ -180,14 +206,14 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
|
||||
// @Produce json
|
||||
// @Success 200 {object} []model.LoginProviderInfo
|
||||
// @Router /auth/{provider}/callback [get]
|
||||
func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) handleOauthCallbackGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
|
||||
var returnUrl *url.URL
|
||||
var returnParams string
|
||||
redirectToReturn := func() {
|
||||
c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams)
|
||||
respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams)
|
||||
}
|
||||
|
||||
if currentSession.OauthReturnTo != "" {
|
||||
@ -207,20 +233,20 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
|
||||
returnParams = queryParams.Encode()
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Message: "already logged in"})
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Message: "already logged in"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
provider := c.Param("provider")
|
||||
oauthCode := c.Query("code")
|
||||
oauthState := c.Query("state")
|
||||
provider := request.Path(r, "provider")
|
||||
oauthCode := request.Query(r, "code")
|
||||
oauthState := request.Query(r, "state")
|
||||
|
||||
if provider != currentSession.OauthProvider {
|
||||
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "invalid oauth provider"})
|
||||
}
|
||||
return
|
||||
@ -229,7 +255,8 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
|
||||
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"})
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -241,12 +268,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
|
||||
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: err.Error()})
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
e.setAuthenticatedUser(c, user)
|
||||
e.setAuthenticatedUser(r, user)
|
||||
|
||||
if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
|
||||
queryParams := returnUrl.Query()
|
||||
@ -254,13 +282,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
|
||||
returnParams = queryParams.Encode()
|
||||
redirectToReturn()
|
||||
} else {
|
||||
c.JSON(http.StatusOK, user)
|
||||
respond.JSON(w, http.StatusOK, user)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) setAuthenticatedUser(r *http.Request, user *domain.User) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
|
||||
currentSession.LoggedIn = true
|
||||
currentSession.IsAdmin = user.IsAdmin
|
||||
@ -274,10 +302,10 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
|
||||
currentSession.OauthProvider = ""
|
||||
currentSession.OauthReturnTo = ""
|
||||
|
||||
e.authenticator.Session.SetData(c, currentSession)
|
||||
e.session.SetData(r.Context(), currentSession)
|
||||
}
|
||||
|
||||
// handleLoginPost returns a gorm handler function.
|
||||
// handleLoginPost returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleLoginPost
|
||||
// @Tags Authentication
|
||||
@ -285,11 +313,11 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) {
|
||||
// @Produce json
|
||||
// @Success 200 {object} []model.LoginProviderInfo
|
||||
// @Router /auth/login [post]
|
||||
func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) handleLoginPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
if currentSession.LoggedIn {
|
||||
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"})
|
||||
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"})
|
||||
return
|
||||
}
|
||||
|
||||
@ -298,25 +326,29 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
|
||||
Password string `json:"password" binding:"required,min=4"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&loginData); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &loginData); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validate.Struct(loginData); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password)
|
||||
user, err := e.app.Authenticator.PlainLogin(context.Background(), loginData.Username, loginData.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
|
||||
return
|
||||
}
|
||||
|
||||
e.setAuthenticatedUser(c, user)
|
||||
e.setAuthenticatedUser(r, user)
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
respond.JSON(w, http.StatusOK, user)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogoutPost returns a gorm handler function.
|
||||
// handleLogoutPost returns a gorm Handler function.
|
||||
//
|
||||
// @ID auth_handleLogoutGet
|
||||
// @Tags Authentication
|
||||
@ -324,22 +356,22 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
|
||||
// @Produce json
|
||||
// @Success 200 {object} []model.LoginProviderInfo
|
||||
// @Router /auth/logout [get]
|
||||
func (e authEndpoint) handleLogoutPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
currentSession := e.authenticator.Session.GetData(c)
|
||||
func (e AuthEndpoint) handleLogoutPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
currentSession := e.session.GetData(r.Context())
|
||||
|
||||
if !currentSession.LoggedIn { // Not logged in
|
||||
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"})
|
||||
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"})
|
||||
return
|
||||
}
|
||||
|
||||
e.authenticator.Session.DestroyData(c)
|
||||
c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"})
|
||||
e.session.DestroyData(r.Context())
|
||||
respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"})
|
||||
}
|
||||
}
|
||||
|
||||
// isValidReturnUrl checks if the given return URL matches the configured external URL of the application.
|
||||
func (e authEndpoint) isValidReturnUrl(returnUrl string) bool {
|
||||
func (e AuthEndpoint) isValidReturnUrl(returnUrl string) bool {
|
||||
if !strings.HasPrefix(returnUrl, e.app.Config.Web.ExternalUrl) {
|
||||
return false
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
@ -9,57 +8,61 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
)
|
||||
|
||||
//go:embed frontend_config.js.gotpl
|
||||
var frontendJs embed.FS
|
||||
|
||||
type configEndpoint struct {
|
||||
app *app.App
|
||||
authenticator *authenticationHandler
|
||||
type ConfigEndpoint struct {
|
||||
cfg *config.Config
|
||||
authenticator Authenticator
|
||||
|
||||
tpl *template.Template
|
||||
tpl *respond.TemplateRenderer
|
||||
}
|
||||
|
||||
func newConfigEndpoint(app *app.App, authenticator *authenticationHandler) configEndpoint {
|
||||
ep := configEndpoint{
|
||||
app: app,
|
||||
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
|
||||
ep := ConfigEndpoint{
|
||||
cfg: cfg,
|
||||
authenticator: authenticator,
|
||||
tpl: template.Must(template.ParseFS(frontendJs, "frontend_config.js.gotpl")),
|
||||
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
|
||||
"frontend_config.js.gotpl"))),
|
||||
}
|
||||
|
||||
return ep
|
||||
}
|
||||
|
||||
func (e configEndpoint) GetName() string {
|
||||
func (e ConfigEndpoint) GetName() string {
|
||||
return "ConfigEndpoint"
|
||||
}
|
||||
|
||||
func (e configEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
|
||||
apiGroup := g.Group("/config")
|
||||
func (e ConfigEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/config")
|
||||
|
||||
apiGroup.GET("/frontend.js", e.handleConfigJsGet())
|
||||
apiGroup.GET("/settings", e.authenticator.LoggedIn(), e.handleSettingsGet())
|
||||
apiGroup.HandleFunc("GET /frontend.js", e.handleConfigJsGet())
|
||||
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("GET /settings", e.handleSettingsGet())
|
||||
}
|
||||
|
||||
// handleConfigJsGet returns a gorm handler function.
|
||||
// handleConfigJsGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID config_handleConfigJsGet
|
||||
// @Tags Configuration
|
||||
// @Summary Get the dynamic frontend configuration javascript.
|
||||
// @Produce text/javascript
|
||||
// @Success 200 string javascript "The JavaScript contents"
|
||||
// @Failure 500
|
||||
// @Router /config/frontend.js [get]
|
||||
func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
backendUrl := fmt.Sprintf("%s/api/v0", e.app.Config.Web.ExternalUrl)
|
||||
if c.GetHeader("x-wg-dev") != "" {
|
||||
referer := c.Request.Header.Get("Referer")
|
||||
func (e ConfigEndpoint) handleConfigJsGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
backendUrl := fmt.Sprintf("%s/api/v0", e.cfg.Web.ExternalUrl)
|
||||
if request.Header(r, "x-wg-dev") != "" {
|
||||
referer := request.Header(r, "Referer")
|
||||
host := "localhost"
|
||||
port := "5000"
|
||||
parsedReferer, err := url.Parse(referer)
|
||||
@ -69,23 +72,17 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
|
||||
backendUrl = fmt.Sprintf("http://%s:%s/api/v0", host,
|
||||
port) // override if request comes from frontend started with npm run dev
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
err := e.tpl.ExecuteTemplate(buf, "frontend_config.js.gotpl", gin.H{
|
||||
|
||||
e.tpl.Render(w, http.StatusOK, "frontend_config.js.gotpl", "text/javascript", map[string]any{
|
||||
"BackendUrl": backendUrl,
|
||||
"Version": internal.Version,
|
||||
"SiteTitle": e.app.Config.Web.SiteTitle,
|
||||
"SiteCompanyName": e.app.Config.Web.SiteCompanyName,
|
||||
"SiteTitle": e.cfg.Web.SiteTitle,
|
||||
"SiteCompanyName": e.cfg.Web.SiteCompanyName,
|
||||
})
|
||||
if err != nil {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/javascript", buf.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// handleSettingsGet returns a gorm handler function.
|
||||
// handleSettingsGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID config_handleSettingsGet
|
||||
// @Tags Configuration
|
||||
@ -94,13 +91,13 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc {
|
||||
// @Success 200 {object} model.Settings
|
||||
// @Success 200 string javascript "The JavaScript contents"
|
||||
// @Router /config/settings [get]
|
||||
func (e configEndpoint) handleSettingsGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, model.Settings{
|
||||
MailLinkOnly: e.app.Config.Mail.LinkOnly,
|
||||
PersistentConfigSupported: e.app.Config.Advanced.ConfigStoragePath != "",
|
||||
SelfProvisioning: e.app.Config.Core.SelfProvisioningAllowed,
|
||||
ApiAdminOnly: e.app.Config.Advanced.ApiAdminOnly,
|
||||
func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
respond.JSON(w, http.StatusOK, model.Settings{
|
||||
MailLinkOnly: e.cfg.Mail.LinkOnly,
|
||||
PersistentConfigSupported: e.cfg.Advanced.ConfigStoragePath != "",
|
||||
SelfProvisioning: e.cfg.Core.SelfProvisioningAllowed,
|
||||
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4,39 +4,51 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type interfaceEndpoint struct {
|
||||
type InterfaceEndpoint struct {
|
||||
app *app.App
|
||||
authenticator *authenticationHandler
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func (e interfaceEndpoint) GetName() string {
|
||||
func NewInterfaceEndpoint(app *app.App, authenticator Authenticator, validator Validator) InterfaceEndpoint {
|
||||
return InterfaceEndpoint{
|
||||
app: app,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
}
|
||||
}
|
||||
|
||||
func (e InterfaceEndpoint) GetName() string {
|
||||
return "InterfaceEndpoint"
|
||||
}
|
||||
|
||||
func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
|
||||
apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin))
|
||||
func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/interface")
|
||||
apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin))
|
||||
|
||||
apiGroup.GET("/prepare", e.handlePrepareGet())
|
||||
apiGroup.GET("/all", e.handleAllGet())
|
||||
apiGroup.GET("/get/:id", e.handleSingleGet())
|
||||
apiGroup.PUT("/:id", e.handleUpdatePut())
|
||||
apiGroup.DELETE("/:id", e.handleDelete())
|
||||
apiGroup.POST("/new", e.handleCreatePost())
|
||||
apiGroup.GET("/config/:id", e.handleConfigGet())
|
||||
apiGroup.POST("/:id/save-config", e.handleSaveConfigPost())
|
||||
apiGroup.POST("/:id/apply-peer-defaults", e.handleApplyPeerDefaultsPost())
|
||||
apiGroup.HandleFunc("GET /prepare", e.handlePrepareGet())
|
||||
apiGroup.HandleFunc("GET /all", e.handleAllGet())
|
||||
apiGroup.HandleFunc("GET /get/{id}", e.handleSingleGet())
|
||||
apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut())
|
||||
apiGroup.HandleFunc("DELETE /{id}", e.handleDelete())
|
||||
apiGroup.HandleFunc("POST /new", e.handleCreatePost())
|
||||
apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet())
|
||||
apiGroup.HandleFunc("POST /{id}/save-config", e.handleSaveConfigPost())
|
||||
apiGroup.HandleFunc("POST /{id}/apply-peer-defaults", e.handleApplyPeerDefaultsPost())
|
||||
|
||||
apiGroup.GET("/peers/:id", e.handlePeersGet())
|
||||
apiGroup.HandleFunc("GET /peers/{id}", e.handlePeersGet())
|
||||
}
|
||||
|
||||
// handlePrepareGet returns a gorm handler function.
|
||||
// handlePrepareGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handlePrepareGet
|
||||
// @Tags Interface
|
||||
@ -45,22 +57,21 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationH
|
||||
// @Success 200 {object} model.Interface
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/prepare [get]
|
||||
func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
in, err := e.app.PrepareInterface(ctx)
|
||||
func (e InterfaceEndpoint) handlePrepareGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
in, err := e.app.PrepareInterface(r.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterface(in, nil))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterface(in, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm handler function.
|
||||
// handleAllGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleAllGet
|
||||
// @Tags Interface
|
||||
@ -69,22 +80,21 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
// @Success 200 {object} []model.Interface
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/all [get]
|
||||
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx)
|
||||
func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(r.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterfaces(interfaces, peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSingleGet returns a gorm handler function.
|
||||
// handleSingleGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleSingleGet
|
||||
// @Tags Interface
|
||||
@ -94,30 +104,29 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/get/{id} [get]
|
||||
func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleSingleGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: "missing id parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
|
||||
iface, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterface(iface, peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterface(iface, peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleConfigGet returns a gorm handler function.
|
||||
// handleConfigGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleConfigGet
|
||||
// @Tags Interface
|
||||
@ -127,20 +136,19 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/config/{id} [get]
|
||||
func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleConfigGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: "missing id parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
|
||||
config, err := e.app.GetInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
@ -148,17 +156,17 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
|
||||
configString, err := io.ReadAll(config)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, string(configString))
|
||||
respond.JSON(w, http.StatusOK, string(configString))
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdatePut returns a gorm handler function.
|
||||
// handleUpdatePut returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleUpdatePut
|
||||
// @Tags Interface
|
||||
@ -170,41 +178,44 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/{id} [put]
|
||||
func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
var in model.Interface
|
||||
err := c.BindJSON(&in)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if id != in.Identifier {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in))
|
||||
updatedInterface, peers, err := e.app.UpdateInterface(r.Context(), model.NewDomainInterface(&in))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterface(updatedInterface, peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleCreatePost returns a gorm handler function.
|
||||
// handleCreatePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleCreatePost
|
||||
// @Tags Interface
|
||||
@ -215,30 +226,31 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/new [post]
|
||||
func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var in model.Interface
|
||||
err := c.BindJSON(&in)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newInterface, err := e.app.CreateInterface(ctx, model.NewDomainInterface(&in))
|
||||
newInterface, err := e.app.CreateInterface(r.Context(), model.NewDomainInterface(&in))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterface(newInterface, nil))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterface(newInterface, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// handlePeersGet returns a gorm handler function.
|
||||
// handlePeersGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handlePeersGet
|
||||
// @Tags Interface
|
||||
@ -247,31 +259,29 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Success 200 {object} []model.Peer
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/peers/{id} [get]
|
||||
func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handlePeersGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: "missing id parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
|
||||
_, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeers(peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelete returns a gorm handler function.
|
||||
// handleDelete returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleDelete
|
||||
// @Tags Interface
|
||||
@ -282,29 +292,28 @@ func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/{id} [delete]
|
||||
func (e interfaceEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.app.DeleteInterface(ctx, domain.InterfaceIdentifier(id))
|
||||
err := e.app.DeleteInterface(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSaveConfigPost returns a gorm handler function.
|
||||
// handleSaveConfigPost returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleSaveConfigPost
|
||||
// @Tags Interface
|
||||
@ -315,29 +324,28 @@ func (e interfaceEndpoint) handleDelete() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/{id}/save-config [post]
|
||||
func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleSaveConfigPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.app.PersistInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
|
||||
err := e.app.PersistInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// handleApplyPeerDefaultsPost returns a gorm handler function.
|
||||
// handleApplyPeerDefaultsPost returns a gorm Handler function.
|
||||
//
|
||||
// @ID interfaces_handleApplyPeerDefaultsPost
|
||||
// @Tags Interface
|
||||
@ -349,36 +357,38 @@ func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /interface/{id}/apply-peer-defaults [post]
|
||||
func (e interfaceEndpoint) handleApplyPeerDefaultsPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e InterfaceEndpoint) handleApplyPeerDefaultsPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
var in model.Interface
|
||||
err := c.BindJSON(&in)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(in); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if id != in.Identifier {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
err = e.app.ApplyPeerDefaults(ctx, model.NewDomainInterface(&in))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
if err := e.app.ApplyPeerDefaults(r.Context(), model.NewDomainInterface(&in)); err != nil {
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
@ -4,39 +4,52 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type peerEndpoint struct {
|
||||
type PeerEndpoint struct {
|
||||
app *app.App
|
||||
authenticator *authenticationHandler
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func (e peerEndpoint) GetName() string {
|
||||
func NewPeerEndpoint(app *app.App, authenticator Authenticator, validator Validator) PeerEndpoint {
|
||||
return PeerEndpoint{
|
||||
app: app,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
}
|
||||
}
|
||||
|
||||
func (e PeerEndpoint) GetName() string {
|
||||
return "PeerEndpoint"
|
||||
}
|
||||
|
||||
func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
|
||||
apiGroup := g.Group("/peer", e.authenticator.LoggedIn())
|
||||
func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/peer")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet())
|
||||
apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(), e.handlePrepareGet())
|
||||
apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(), e.handleCreatePost())
|
||||
apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost())
|
||||
apiGroup.GET("/config-qr/:id", e.handleQrCodeGet())
|
||||
apiGroup.POST("/config-mail", e.handleEmailPost())
|
||||
apiGroup.GET("/config/:id", e.handleConfigGet())
|
||||
apiGroup.GET("/:id", e.handleSingleGet())
|
||||
apiGroup.PUT("/:id", e.handleUpdatePut())
|
||||
apiGroup.DELETE("/:id", e.handleDelete())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/all", e.handleAllGet())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/stats", e.handleStatsGet())
|
||||
apiGroup.HandleFunc("GET /iface/{iface}/prepare", e.handlePrepareGet())
|
||||
apiGroup.HandleFunc("POST /iface/{iface}/new", e.handleCreatePost())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /iface/{iface}/multiplenew",
|
||||
e.handleCreateMultiplePost())
|
||||
apiGroup.HandleFunc("GET /config-qr/{id}", e.handleQrCodeGet())
|
||||
apiGroup.HandleFunc("POST /config-mail", e.handleEmailPost())
|
||||
apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet())
|
||||
apiGroup.HandleFunc("GET /{id}", e.handleSingleGet())
|
||||
apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut())
|
||||
apiGroup.HandleFunc("DELETE /{id}", e.handleDelete())
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm handler function.
|
||||
// handleAllGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleAllGet
|
||||
// @Tags Peer
|
||||
@ -47,28 +60,27 @@ func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/iface/{iface}/all [get]
|
||||
func (e peerEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
interfaceId := Base64UrlDecode(c.Param("iface"))
|
||||
func (e PeerEndpoint) handleAllGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
|
||||
if interfaceId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
_, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(interfaceId))
|
||||
_, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(interfaceId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeers(peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSingleGet returns a gorm handler function.
|
||||
// handleSingleGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleSingleGet
|
||||
// @Tags Peer
|
||||
@ -79,28 +91,27 @@ func (e peerEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/{id} [get]
|
||||
func (e peerEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
peerId := Base64UrlDecode(c.Param("id"))
|
||||
func (e PeerEndpoint) handleSingleGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
peerId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if peerId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := e.app.GetPeer(ctx, domain.PeerIdentifier(peerId))
|
||||
peer, err := e.app.GetPeer(r.Context(), domain.PeerIdentifier(peerId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeer(peer))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeer(peer))
|
||||
}
|
||||
}
|
||||
|
||||
// handlePrepareGet returns a gorm handler function.
|
||||
// handlePrepareGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handlePrepareGet
|
||||
// @Tags Peer
|
||||
@ -111,28 +122,27 @@ func (e peerEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/iface/{iface}/prepare [get]
|
||||
func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
interfaceId := Base64UrlDecode(c.Param("iface"))
|
||||
func (e PeerEndpoint) handlePrepareGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
|
||||
if interfaceId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := e.app.PreparePeer(ctx, domain.InterfaceIdentifier(interfaceId))
|
||||
peer, err := e.app.PreparePeer(r.Context(), domain.InterfaceIdentifier(interfaceId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeer(peer))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeer(peer))
|
||||
}
|
||||
}
|
||||
|
||||
// handleCreatePost returns a gorm handler function.
|
||||
// handleCreatePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleCreatePost
|
||||
// @Tags Peer
|
||||
@ -144,40 +154,43 @@ func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/iface/{iface}/new [post]
|
||||
func (e peerEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
interfaceId := Base64UrlDecode(c.Param("iface"))
|
||||
func (e PeerEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
|
||||
if interfaceId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
var p model.Peer
|
||||
err := c.BindJSON(&p)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &p); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(p); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if p.InterfaceIdentifier != interfaceId {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
newPeer, err := e.app.CreatePeer(ctx, model.NewDomainPeer(&p))
|
||||
newPeer, err := e.app.CreatePeer(r.Context(), model.NewDomainPeer(&p))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeer(newPeer))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeer(newPeer))
|
||||
}
|
||||
}
|
||||
|
||||
// handleCreateMultiplePost returns a gorm handler function.
|
||||
// handleCreateMultiplePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleCreateMultiplePost
|
||||
// @Tags Peer
|
||||
@ -189,36 +202,38 @@ func (e peerEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/iface/{iface}/multiplenew [post]
|
||||
func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
interfaceId := Base64UrlDecode(c.Param("iface"))
|
||||
func (e PeerEndpoint) handleCreateMultiplePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
|
||||
if interfaceId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
var req model.MultiPeerRequest
|
||||
err := c.BindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newPeers, err := e.app.CreateMultiplePeers(ctx, domain.InterfaceIdentifier(interfaceId),
|
||||
newPeers, err := e.app.CreateMultiplePeers(r.Context(), domain.InterfaceIdentifier(interfaceId),
|
||||
model.NewDomainPeerCreationRequest(&req))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeers(newPeers))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeers(newPeers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdatePut returns a gorm handler function.
|
||||
// handleUpdatePut returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleUpdatePut
|
||||
// @Tags Peer
|
||||
@ -230,40 +245,43 @@ func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/{id} [put]
|
||||
func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
peerId := Base64UrlDecode(c.Param("id"))
|
||||
func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
peerId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if peerId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
var p model.Peer
|
||||
err := c.BindJSON(&p)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &p); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(p); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if p.Identifier != peerId {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
updatedPeer, err := e.app.UpdatePeer(ctx, model.NewDomainPeer(&p))
|
||||
updatedPeer, err := e.app.UpdatePeer(r.Context(), model.NewDomainPeer(&p))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeer(updatedPeer))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeer(updatedPeer))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelete returns a gorm handler function.
|
||||
// handleDelete returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleDelete
|
||||
// @Tags Peer
|
||||
@ -274,28 +292,26 @@ func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/{id} [delete]
|
||||
func (e peerEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e PeerEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.app.DeletePeer(ctx, domain.PeerIdentifier(id))
|
||||
err := e.app.DeletePeer(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConfigGet returns a gorm handler function.
|
||||
// handleConfigGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleConfigGet
|
||||
// @Tags Peer
|
||||
@ -306,21 +322,19 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/config/{id} [get]
|
||||
func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e PeerEndpoint) handleConfigGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: "missing id parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id))
|
||||
config, err := e.app.GetPeerConfig(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
@ -328,17 +342,17 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
|
||||
configString, err := io.ReadAll(config)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, string(configString))
|
||||
respond.JSON(w, http.StatusOK, string(configString))
|
||||
}
|
||||
}
|
||||
|
||||
// handleQrCodeGet returns a gorm handler function.
|
||||
// handleQrCodeGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleQrCodeGet
|
||||
// @Tags Peer
|
||||
@ -350,20 +364,19 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/config-qr/{id} [get]
|
||||
func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e PeerEndpoint) handleQrCodeGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: "missing id parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id))
|
||||
config, err := e.app.GetPeerConfigQrCode(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
@ -371,17 +384,17 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
|
||||
|
||||
configData, err := io.ReadAll(config)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "image/png", configData)
|
||||
respond.Data(w, http.StatusOK, "image/png", configData)
|
||||
}
|
||||
}
|
||||
|
||||
// handleEmailPost returns a gorm handler function.
|
||||
// handleEmailPost returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleEmailPost
|
||||
// @Tags Peer
|
||||
@ -392,38 +405,39 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/config-mail [post]
|
||||
func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
func (e PeerEndpoint) handleEmailPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.PeerMailRequest
|
||||
err := c.BindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Identifiers) == 0 {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
|
||||
for i := range req.Identifiers {
|
||||
peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
|
||||
}
|
||||
err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
if err := e.app.SendPeerEmail(r.Context(), req.LinkOnly, peerIds...); err != nil {
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStatsGet returns a gorm handler function.
|
||||
// handleStatsGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID peers_handleStatsGet
|
||||
// @Tags Peer
|
||||
@ -434,23 +448,22 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /peer/iface/{iface}/stats [get]
|
||||
func (e peerEndpoint) handleStatsGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
interfaceId := Base64UrlDecode(c.Param("iface"))
|
||||
func (e PeerEndpoint) handleStatsGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
interfaceId := Base64UrlDecode(request.Path(r, "iface"))
|
||||
if interfaceId == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := e.app.GetPeerStats(ctx, domain.InterfaceIdentifier(interfaceId))
|
||||
stats, err := e.app.GetPeerStats(r.Context(), domain.InterfaceIdentifier(interfaceId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
|
||||
}
|
||||
}
|
||||
|
@ -5,20 +5,29 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
)
|
||||
|
||||
type testEndpoint struct{}
|
||||
type TestEndpoint struct {
|
||||
authenticator Authenticator
|
||||
}
|
||||
|
||||
func (e testEndpoint) GetName() string {
|
||||
func NewTestEndpoint(authenticator Authenticator) TestEndpoint {
|
||||
return TestEndpoint{
|
||||
authenticator: authenticator,
|
||||
}
|
||||
}
|
||||
|
||||
func (e TestEndpoint) GetName() string {
|
||||
return "TestEndpoint"
|
||||
}
|
||||
|
||||
func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
|
||||
g.GET("/now", e.handleCurrentTimeGet())
|
||||
g.GET("/hostname", e.handleHostnameGet())
|
||||
func (e TestEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
g.HandleFunc("GET /now", e.handleCurrentTimeGet())
|
||||
g.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /hostname", e.handleHostnameGet())
|
||||
}
|
||||
|
||||
// handleCurrentTimeGet represents the GET endpoint that responds the current time
|
||||
@ -31,15 +40,15 @@ func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
|
||||
// @Success 200 {object} string
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /now [get]
|
||||
func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
func (e TestEndpoint) handleCurrentTimeGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if time.Now().Second() == 0 {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: "invalid time",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, time.Now().String())
|
||||
respond.JSON(w, http.StatusOK, time.Now().String())
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,15 +62,15 @@ func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc {
|
||||
// @Success 200 {object} string
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /hostname [get]
|
||||
func (e testEndpoint) handleHostnameGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
func (e TestEndpoint) handleHostnameGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
respond.JSON(w, http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, hostname)
|
||||
respond.JSON(w, http.StatusOK, hostname)
|
||||
}
|
||||
}
|
||||
|
@ -3,38 +3,50 @@ package handlers
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type userEndpoint struct {
|
||||
type UserEndpoint struct {
|
||||
app *app.App
|
||||
authenticator *authenticationHandler
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func (e userEndpoint) GetName() string {
|
||||
func NewUserEndpoint(app *app.App, authenticator Authenticator, validator Validator) UserEndpoint {
|
||||
return UserEndpoint{
|
||||
app: app,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
}
|
||||
}
|
||||
|
||||
func (e UserEndpoint) GetName() string {
|
||||
return "UserEndpoint"
|
||||
}
|
||||
|
||||
func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) {
|
||||
apiGroup := g.Group("/user", e.authenticator.LoggedIn())
|
||||
func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/user")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet())
|
||||
apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut())
|
||||
apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete())
|
||||
apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet())
|
||||
apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet())
|
||||
apiGroup.GET("/:id/interfaces", e.authenticator.UserIdMatch("id"), e.handleInterfacesGet())
|
||||
apiGroup.POST("/:id/api/enable", e.authenticator.UserIdMatch("id"), e.handleApiEnablePost())
|
||||
apiGroup.POST("/:id/api/disable", e.authenticator.UserIdMatch("id"), e.handleApiDisablePost())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}", e.handleSingleGet())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("PUT /{id}", e.handleUpdatePut())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("DELETE /{id}", e.handleDelete())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/peers", e.handlePeersGet())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/stats", e.handleStatsGet())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/interfaces", e.handleInterfacesGet())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/enable", e.handleApiEnablePost())
|
||||
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/disable", e.handleApiDisablePost())
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm handler function.
|
||||
// handleAllGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleAllGet
|
||||
// @Tags Users
|
||||
@ -43,22 +55,20 @@ func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle
|
||||
// @Success 200 {object} []model.User
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/all [get]
|
||||
func (e userEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
users, err := e.app.GetAllUsers(ctx)
|
||||
func (e UserEndpoint) handleAllGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := e.app.GetAllUsers(r.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUsers(users))
|
||||
respond.JSON(w, http.StatusOK, model.NewUsers(users))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSingleGet returns a gorm handler function.
|
||||
// handleSingleGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleSingleGet
|
||||
// @Tags Users
|
||||
@ -68,28 +78,26 @@ func (e userEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Success 200 {object} model.User
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id} [get]
|
||||
func (e userEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleSingleGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := e.app.GetUser(ctx, domain.UserIdentifier(id))
|
||||
user, err := e.app.GetUser(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUser(user, true))
|
||||
respond.JSON(w, http.StatusOK, model.NewUser(user, true))
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdatePut returns a gorm handler function.
|
||||
// handleUpdatePut returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleUpdatePut
|
||||
// @Tags Users
|
||||
@ -101,40 +109,42 @@ func (e userEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id} [put]
|
||||
func (e userEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
var user model.User
|
||||
err := c.BindJSON(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if id != user.Identifier {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
updateUser, err := e.app.UpdateUser(ctx, model.NewDomainUser(&user))
|
||||
updateUser, err := e.app.UpdateUser(r.Context(), model.NewDomainUser(&user))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUser(updateUser, false))
|
||||
respond.JSON(w, http.StatusOK, model.NewUser(updateUser, false))
|
||||
}
|
||||
}
|
||||
|
||||
// handleCreatePost returns a gorm handler function.
|
||||
// handleCreatePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleCreatePost
|
||||
// @Tags Users
|
||||
@ -145,29 +155,30 @@ func (e userEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/new [post]
|
||||
func (e userEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e UserEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var user model.User
|
||||
err := c.BindJSON(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := e.app.CreateUser(ctx, model.NewDomainUser(&user))
|
||||
newUser, err := e.app.CreateUser(r.Context(), model.NewDomainUser(&user))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUser(newUser, false))
|
||||
respond.JSON(w, http.StatusOK, model.NewUser(newUser, false))
|
||||
}
|
||||
}
|
||||
|
||||
// handlePeersGet returns a gorm handler function.
|
||||
// handlePeersGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handlePeersGet
|
||||
// @Tags Users
|
||||
@ -178,29 +189,27 @@ func (e userEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id}/peers [get]
|
||||
func (e userEndpoint) handlePeersGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
userId := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handlePeersGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if userId == "" {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := e.app.GetUserPeers(ctx, domain.UserIdentifier(userId))
|
||||
peers, err := e.app.GetUserPeers(r.Context(), domain.UserIdentifier(userId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeers(peers))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeers(peers))
|
||||
}
|
||||
}
|
||||
|
||||
// handleStatsGet returns a gorm handler function.
|
||||
// handleStatsGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleStatsGet
|
||||
// @Tags Users
|
||||
@ -211,29 +220,27 @@ func (e userEndpoint) handlePeersGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id}/stats [get]
|
||||
func (e userEndpoint) handleStatsGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
userId := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleStatsGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if userId == "" {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := e.app.GetUserPeerStats(ctx, domain.UserIdentifier(userId))
|
||||
stats, err := e.app.GetUserPeerStats(r.Context(), domain.UserIdentifier(userId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
|
||||
respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats))
|
||||
}
|
||||
}
|
||||
|
||||
// handleInterfacesGet returns a gorm handler function.
|
||||
// handleInterfacesGet returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleInterfacesGet
|
||||
// @Tags Users
|
||||
@ -244,29 +251,27 @@ func (e userEndpoint) handleStatsGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id}/interfaces [get]
|
||||
func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
userId := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleInterfacesGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if userId == "" {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := e.app.GetUserInterfaces(ctx, domain.UserIdentifier(userId))
|
||||
peers, err := e.app.GetUserInterfaces(r.Context(), domain.UserIdentifier(userId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewInterfaces(peers, nil))
|
||||
respond.JSON(w, http.StatusOK, model.NewInterfaces(peers, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelete returns a gorm handler function.
|
||||
// handleDelete returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleDelete
|
||||
// @Tags Users
|
||||
@ -277,28 +282,26 @@ func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id} [delete]
|
||||
func (e userEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := Base64UrlDecode(request.Path(r, "id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.app.DeleteUser(ctx, domain.UserIdentifier(id))
|
||||
err := e.app.DeleteUser(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// handleApiEnablePost returns a gorm handler function.
|
||||
// handleApiEnablePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleApiEnablePost
|
||||
// @Tags Users
|
||||
@ -308,29 +311,27 @@ func (e userEndpoint) handleDelete() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id}/api/enable [post]
|
||||
func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
userId := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleApiEnablePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if userId == "" {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := e.app.ActivateApi(ctx, domain.UserIdentifier(userId))
|
||||
user, err := e.app.ActivateApi(r.Context(), domain.UserIdentifier(userId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUser(user, true))
|
||||
respond.JSON(w, http.StatusOK, model.NewUser(user, true))
|
||||
}
|
||||
}
|
||||
|
||||
// handleApiDisablePost returns a gorm handler function.
|
||||
// handleApiDisablePost returns a gorm Handler function.
|
||||
//
|
||||
// @ID users_handleApiDisablePost
|
||||
// @Tags Users
|
||||
@ -340,24 +341,22 @@ func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc {
|
||||
// @Failure 400 {object} model.Error
|
||||
// @Failure 500 {object} model.Error
|
||||
// @Router /user/{id}/api/disable [post]
|
||||
func (e userEndpoint) handleApiDisablePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
userId := Base64UrlDecode(c.Param("id"))
|
||||
func (e UserEndpoint) handleApiDisablePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userId := Base64UrlDecode(request.Path(r, "id"))
|
||||
if userId == "" {
|
||||
c.JSON(http.StatusBadRequest,
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := e.app.DeactivateApi(ctx, domain.UserIdentifier(userId))
|
||||
user, err := e.app.DeactivateApi(r.Context(), domain.UserIdentifier(userId))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError,
|
||||
respond.JSON(w, http.StatusInternalServerError,
|
||||
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewUser(user, false))
|
||||
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
|
||||
}
|
||||
}
|
||||
|
@ -1,111 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
|
||||
)
|
||||
|
||||
type authenticationHandler struct {
|
||||
app *app.App
|
||||
Session SessionStore
|
||||
}
|
||||
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := h.Session.GetData(c)
|
||||
|
||||
if !session.LoggedIn {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "not logged in"})
|
||||
return
|
||||
}
|
||||
|
||||
if !UserHasScopes(session, scopes...) {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if logged-in user is still valid
|
||||
if !h.app.Authenticator.IsUserValid(c.Request.Context(), domain.UserIdentifier(session.UserIdentifier)) {
|
||||
h.Session.DestroyData(c)
|
||||
c.Abort()
|
||||
c.JSON(http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{
|
||||
Id: domain.UserIdentifier(session.UserIdentifier),
|
||||
IsAdmin: session.IsAdmin,
|
||||
})
|
||||
|
||||
// Continue down the chain to handler etc
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
|
||||
func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := h.Session.GetData(c)
|
||||
|
||||
if session.IsAdmin {
|
||||
c.Next() // Admins can do everything
|
||||
return
|
||||
}
|
||||
|
||||
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
|
||||
requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter)))
|
||||
|
||||
if sessionUserId != requestUserId {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
// Continue down the chain to handler etc
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func UserHasScopes(session SessionData, scopes ...Scope) bool {
|
||||
// No scopes give, so the check should succeed
|
||||
if len(scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// check if user has admin scope
|
||||
if session.IsAdmin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if admin scope is required
|
||||
for _, scope := range scopes {
|
||||
if scope == ScopeAdmin {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// For all other scopes, a logged-in user is sufficient (for now)
|
||||
if session.LoggedIn {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
@ -1,92 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(SessionData{})
|
||||
}
|
||||
|
||||
type SessionData struct {
|
||||
LoggedIn bool
|
||||
IsAdmin bool
|
||||
|
||||
UserIdentifier string
|
||||
|
||||
Firstname string
|
||||
Lastname string
|
||||
Email string
|
||||
|
||||
OauthState string
|
||||
OauthNonce string
|
||||
OauthProvider string
|
||||
OauthReturnTo string
|
||||
}
|
||||
|
||||
type SessionStore interface {
|
||||
DefaultSessionData() SessionData
|
||||
|
||||
GetData(c *gin.Context) SessionData
|
||||
SetData(c *gin.Context, data SessionData)
|
||||
|
||||
DestroyData(c *gin.Context)
|
||||
}
|
||||
|
||||
type GinSessionStore struct {
|
||||
sessionIdentifier string
|
||||
}
|
||||
|
||||
func (g GinSessionStore) GetData(c *gin.Context) SessionData {
|
||||
session := sessions.Default(c)
|
||||
rawSessionData := session.Get(g.sessionIdentifier)
|
||||
|
||||
var sessionData SessionData
|
||||
if rawSessionData != nil {
|
||||
sessionData = rawSessionData.(SessionData)
|
||||
} else {
|
||||
// init a new default session
|
||||
sessionData = g.DefaultSessionData()
|
||||
session.Set(g.sessionIdentifier, sessionData)
|
||||
if err := session.Save(); err != nil {
|
||||
panic(fmt.Sprintf("failed to store session: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return sessionData
|
||||
}
|
||||
|
||||
func (g GinSessionStore) DefaultSessionData() SessionData {
|
||||
return SessionData{
|
||||
LoggedIn: false,
|
||||
IsAdmin: false,
|
||||
UserIdentifier: "",
|
||||
Firstname: "",
|
||||
Lastname: "",
|
||||
Email: "",
|
||||
OauthState: "",
|
||||
OauthNonce: "",
|
||||
OauthProvider: "",
|
||||
OauthReturnTo: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (g GinSessionStore) SetData(c *gin.Context, data SessionData) {
|
||||
session := sessions.Default(c)
|
||||
session.Set(g.sessionIdentifier, data)
|
||||
if err := session.Save(); err != nil {
|
||||
panic(fmt.Sprintf("failed to store session: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (g GinSessionStore) DestroyData(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete(g.sessionIdentifier)
|
||||
if err := session.Save(); err != nil {
|
||||
panic(fmt.Sprintf("failed to store session: %v", err))
|
||||
}
|
||||
}
|
126
internal/app/api/v0/handlers/web_authentication.go
Normal file
126
internal/app/api/v0/handlers/web_authentication.go
Normal file
@ -0,0 +1,126 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
|
||||
)
|
||||
|
||||
type UserAuthenticator interface {
|
||||
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
|
||||
}
|
||||
|
||||
type AuthenticationHandler struct {
|
||||
authenticator UserAuthenticator
|
||||
session Session
|
||||
}
|
||||
|
||||
func NewAuthenticationHandler(authenticator UserAuthenticator, session Session) AuthenticationHandler {
|
||||
return AuthenticationHandler{
|
||||
authenticator: authenticator,
|
||||
session: session,
|
||||
}
|
||||
}
|
||||
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session := h.session.GetData(r.Context())
|
||||
|
||||
if !session.LoggedIn {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "not logged in"})
|
||||
return
|
||||
}
|
||||
|
||||
if !UserHasScopes(session, scopes...) {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusForbidden,
|
||||
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if logged-in user is still valid
|
||||
if !h.authenticator.IsUserValid(r.Context(), domain.UserIdentifier(session.UserIdentifier)) {
|
||||
h.session.DestroyData(r.Context())
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{
|
||||
Id: domain.UserIdentifier(session.UserIdentifier),
|
||||
IsAdmin: session.IsAdmin,
|
||||
})
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Continue down the chain to Handler etc
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
|
||||
func (h AuthenticationHandler) UserIdMatch(idParameter string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session := h.session.GetData(r.Context())
|
||||
|
||||
if session.IsAdmin {
|
||||
next.ServeHTTP(w, r) // Admins can do everything
|
||||
return
|
||||
}
|
||||
|
||||
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
|
||||
requestUserId := domain.UserIdentifier(Base64UrlDecode(request.Path(r, idParameter)))
|
||||
|
||||
if sessionUserId != requestUserId {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusForbidden,
|
||||
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
// Continue down the chain to Handler etc
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func UserHasScopes(session SessionData, scopes ...Scope) bool {
|
||||
// No scopes give, so the check should succeed
|
||||
if len(scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// check if user has admin scope
|
||||
if session.IsAdmin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if admin scope is required
|
||||
for _, scope := range scopes {
|
||||
if scope == ScopeAdmin {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// For all other scopes, a logged-in user is sufficient (for now)
|
||||
if session.LoggedIn {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
88
internal/app/api/v0/handlers/web_session.go
Normal file
88
internal/app/api/v0/handlers/web_session.go
Normal file
@ -0,0 +1,88 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/scs/v2"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(SessionData{})
|
||||
}
|
||||
|
||||
type SessionData struct {
|
||||
LoggedIn bool
|
||||
IsAdmin bool
|
||||
|
||||
UserIdentifier string
|
||||
|
||||
Firstname string
|
||||
Lastname string
|
||||
Email string
|
||||
|
||||
OauthState string
|
||||
OauthNonce string
|
||||
OauthProvider string
|
||||
OauthReturnTo string
|
||||
|
||||
CsrfToken string
|
||||
}
|
||||
|
||||
const sessionApiV0Key = "session_api_v0"
|
||||
|
||||
type SessionWrapper struct {
|
||||
*scs.SessionManager
|
||||
}
|
||||
|
||||
func NewSessionWrapper(cfg *config.Config) *SessionWrapper {
|
||||
sessionManager := scs.New()
|
||||
sessionManager.Lifetime = 24 * time.Hour
|
||||
sessionManager.IdleTimeout = 1 * time.Hour
|
||||
sessionManager.Cookie.Name = cfg.Web.SessionIdentifier
|
||||
sessionManager.Cookie.Secure = strings.HasPrefix(cfg.Web.ExternalUrl, "https")
|
||||
sessionManager.Cookie.HttpOnly = true
|
||||
sessionManager.Cookie.SameSite = http.SameSiteLaxMode
|
||||
sessionManager.Cookie.Path = "/"
|
||||
sessionManager.Cookie.Persist = false
|
||||
|
||||
wrappedSessionManager := &SessionWrapper{sessionManager}
|
||||
|
||||
return wrappedSessionManager
|
||||
}
|
||||
|
||||
func (s *SessionWrapper) SetData(ctx context.Context, value SessionData) {
|
||||
s.SessionManager.Put(ctx, sessionApiV0Key, value)
|
||||
}
|
||||
|
||||
func (s *SessionWrapper) GetData(ctx context.Context) SessionData {
|
||||
sessionData, ok := s.SessionManager.Get(ctx, sessionApiV0Key).(SessionData)
|
||||
if !ok {
|
||||
return s.defaultSessionData()
|
||||
}
|
||||
return sessionData
|
||||
}
|
||||
|
||||
func (s *SessionWrapper) DestroyData(ctx context.Context) {
|
||||
_ = s.SessionManager.Destroy(ctx)
|
||||
}
|
||||
|
||||
func (s *SessionWrapper) defaultSessionData() SessionData {
|
||||
return SessionData{
|
||||
LoggedIn: false,
|
||||
IsAdmin: false,
|
||||
UserIdentifier: "",
|
||||
Firstname: "",
|
||||
Lastname: "",
|
||||
Email: "",
|
||||
OauthState: "",
|
||||
OauthNonce: "",
|
||||
OauthProvider: "",
|
||||
OauthReturnTo: "",
|
||||
}
|
||||
}
|
@ -4,17 +4,19 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/middleware/cors"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
// GetName returns the name of the handler.
|
||||
GetName() string
|
||||
RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler)
|
||||
// RegisterRoutes registers the routes for the handler. The session manager is passed to the handler.
|
||||
RegisterRoutes(g *routegroup.Bundle)
|
||||
}
|
||||
|
||||
// To compile the API documentation use the
|
||||
@ -38,18 +40,14 @@ type Handler interface {
|
||||
// @BasePath /api/v1
|
||||
// @query.collection.format multi
|
||||
|
||||
func NewRestApi(userSource UserSource, handlers ...Handler) core.ApiEndpointSetupFunc {
|
||||
authenticator := &authenticationHandler{
|
||||
userSource: userSource,
|
||||
}
|
||||
|
||||
func NewRestApi(handlers ...Handler) core.ApiEndpointSetupFunc {
|
||||
return func() (core.ApiVersion, core.GroupSetupFn) {
|
||||
return "v1", func(group *gin.RouterGroup) {
|
||||
group.Use(cors.Default())
|
||||
return "v1", func(group *routegroup.Bundle) {
|
||||
group.Use(cors.New().Handler)
|
||||
|
||||
// Handler functions
|
||||
for _, h := range handlers {
|
||||
h.RegisterRoutes(group, authenticator)
|
||||
h.RegisterRoutes(group)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -80,3 +78,12 @@ func ParseServiceError(err error) (int, models.Error) {
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
type Authenticator interface {
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Validator interface {
|
||||
Struct(s interface{}) error
|
||||
}
|
||||
|
@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
@ -19,12 +21,20 @@ type InterfaceEndpointInterfaceService interface {
|
||||
}
|
||||
|
||||
type InterfaceEndpoint struct {
|
||||
interfaces InterfaceEndpointInterfaceService
|
||||
interfaces InterfaceEndpointInterfaceService
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func NewInterfaceEndpoint(interfaceService InterfaceEndpointInterfaceService) *InterfaceEndpoint {
|
||||
func NewInterfaceEndpoint(
|
||||
authenticator Authenticator,
|
||||
validator Validator,
|
||||
interfaceService InterfaceEndpointInterfaceService,
|
||||
) *InterfaceEndpoint {
|
||||
return &InterfaceEndpoint{
|
||||
interfaces: interfaceService,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
interfaces: interfaceService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,15 +42,16 @@ func (e InterfaceEndpoint) GetName() string {
|
||||
return "InterfaceEndpoint"
|
||||
}
|
||||
|
||||
func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/interface", authenticator.LoggedIn())
|
||||
func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/interface")
|
||||
apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin))
|
||||
|
||||
apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleByIdGet())
|
||||
apiGroup.HandleFunc("GET /all", e.handleAllGet())
|
||||
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
|
||||
|
||||
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
|
||||
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
|
||||
apiGroup.HandleFunc("POST /new", e.handleCreatePost())
|
||||
apiGroup.HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
|
||||
apiGroup.HandleFunc("DELETE /by-id/{id}", e.handleDelete())
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm Handler function.
|
||||
@ -54,17 +65,16 @@ func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /interface/all [get]
|
||||
// @Security BasicAuth
|
||||
func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(ctx)
|
||||
func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(r.Context())
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface))
|
||||
respond.JSON(w, http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface))
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,23 +92,23 @@ func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /interface/by-id/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e InterfaceEndpoint) handleByIdGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
iface, interfacePeers, err := e.interfaces.GetById(ctx, domain.InterfaceIdentifier(id))
|
||||
iface, interfacePeers, err := e.interfaces.GetById(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewInterface(iface, interfacePeers))
|
||||
respond.JSON(w, http.StatusOK, models.NewInterface(iface, interfacePeers))
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,24 +127,26 @@ func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /interface/new [post]
|
||||
// @Security BasicAuth
|
||||
func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var iface models.Interface
|
||||
err := c.BindJSON(&iface)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &iface); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(iface); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newInterface, err := e.interfaces.Create(ctx, models.NewDomainInterface(&iface))
|
||||
newInterface, err := e.interfaces.Create(r.Context(), models.NewDomainInterface(&iface))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewInterface(newInterface, nil))
|
||||
respond.JSON(w, http.StatusOK, models.NewInterface(newInterface, nil))
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,34 +166,43 @@ func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /interface/by-id/{id} [put]
|
||||
// @Security BasicAuth
|
||||
func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
var iface models.Interface
|
||||
err := c.BindJSON(&iface)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &iface); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(iface); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if id != iface.Identifier {
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
updatedInterface, updatedInterfacePeers, err := e.interfaces.Update(
|
||||
ctx,
|
||||
r.Context(),
|
||||
domain.InterfaceIdentifier(id),
|
||||
models.NewDomainInterface(&iface),
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers))
|
||||
respond.JSON(w, http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers))
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,22 +221,22 @@ func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /interface/by-id/{id} [delete]
|
||||
// @Security BasicAuth
|
||||
func (e InterfaceEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e InterfaceEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.interfaces.Delete(ctx, domain.InterfaceIdentifier(id))
|
||||
err := e.interfaces.Delete(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
@ -17,12 +19,20 @@ type MetricsEndpointStatisticsService interface {
|
||||
}
|
||||
|
||||
type MetricsEndpoint struct {
|
||||
metrics MetricsEndpointStatisticsService
|
||||
metrics MetricsEndpointStatisticsService
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func NewMetricsEndpoint(metrics MetricsEndpointStatisticsService) *MetricsEndpoint {
|
||||
func NewMetricsEndpoint(
|
||||
authenticator Authenticator,
|
||||
validator Validator,
|
||||
metrics MetricsEndpointStatisticsService,
|
||||
) *MetricsEndpoint {
|
||||
return &MetricsEndpoint{
|
||||
metrics: metrics,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,12 +40,14 @@ func (e MetricsEndpoint) GetName() string {
|
||||
return "MetricsEndpoint"
|
||||
}
|
||||
|
||||
func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/metrics", authenticator.LoggedIn())
|
||||
func (e MetricsEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/metrics")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleMetricsForInterfaceGet())
|
||||
apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleMetricsForUserGet())
|
||||
apiGroup.GET("/by-peer/:id", authenticator.LoggedIn(), e.handleMetricsForPeerGet())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}",
|
||||
e.handleMetricsForInterfaceGet())
|
||||
apiGroup.HandleFunc("GET /by-user/{id}", e.handleMetricsForUserGet())
|
||||
apiGroup.HandleFunc("GET /by-peer/{id}", e.handleMetricsForPeerGet())
|
||||
}
|
||||
|
||||
// handleMetricsForInterfaceGet returns a gorm Handler function.
|
||||
@ -52,23 +64,23 @@ func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authe
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /metrics/by-interface/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e MetricsEndpoint) handleMetricsForInterfaceGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
interfaceMetrics, err := e.metrics.GetForInterface(ctx, domain.InterfaceIdentifier(id))
|
||||
interfaceMetrics, err := e.metrics.GetForInterface(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics))
|
||||
respond.JSON(w, http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics))
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,23 +98,23 @@ func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /metrics/by-user/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e MetricsEndpoint) handleMetricsForUserGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, userMetrics, err := e.metrics.GetForUser(ctx, domain.UserIdentifier(id))
|
||||
user, userMetrics, err := e.metrics.GetForUser(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUserMetrics(user, userMetrics))
|
||||
respond.JSON(w, http.StatusOK, models.NewUserMetrics(user, userMetrics))
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,22 +132,22 @@ func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /metrics/by-peer/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e MetricsEndpoint) handleMetricsForPeerGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e MetricsEndpoint) handleMetricsForPeerGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
peerMetrics, err := e.metrics.GetForPeer(ctx, domain.PeerIdentifier(id))
|
||||
peerMetrics, err := e.metrics.GetForPeer(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeerMetrics(peerMetrics))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeerMetrics(peerMetrics))
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
@ -20,12 +22,19 @@ type PeerService interface {
|
||||
}
|
||||
|
||||
type PeerEndpoint struct {
|
||||
peers PeerService
|
||||
peers PeerService
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func NewPeerEndpoint(peerService PeerService) *PeerEndpoint {
|
||||
func NewPeerEndpoint(
|
||||
authenticator Authenticator,
|
||||
validator Validator, peerService PeerService,
|
||||
) *PeerEndpoint {
|
||||
return &PeerEndpoint{
|
||||
peers: peerService,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
peers: peerService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,16 +42,18 @@ func (e PeerEndpoint) GetName() string {
|
||||
return "PeerEndpoint"
|
||||
}
|
||||
|
||||
func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/peer", authenticator.LoggedIn())
|
||||
func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/peer")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleAllForInterfaceGet())
|
||||
apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleAllForUserGet())
|
||||
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}",
|
||||
e.handleAllForInterfaceGet())
|
||||
apiGroup.HandleFunc("GET /by-user/{id}", e.handleAllForUserGet())
|
||||
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
|
||||
|
||||
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
|
||||
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete())
|
||||
}
|
||||
|
||||
// handleAllForInterfaceGet returns a gorm Handler function.
|
||||
@ -57,23 +68,23 @@ func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/by-interface/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e PeerEndpoint) handleAllForInterfaceGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing interface id"})
|
||||
return
|
||||
}
|
||||
|
||||
interfacePeers, err := e.peers.GetForInterface(ctx, domain.InterfaceIdentifier(id))
|
||||
interfacePeers, err := e.peers.GetForInterface(r.Context(), domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeers(interfacePeers))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers))
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,23 +101,23 @@ func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/by-user/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e PeerEndpoint) handleAllForUserGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
interfacePeers, err := e.peers.GetForUser(ctx, domain.UserIdentifier(id))
|
||||
interfacePeers, err := e.peers.GetForUser(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeers(interfacePeers))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers))
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,23 +136,23 @@ func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/by-id/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e PeerEndpoint) handleByIdGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := e.peers.GetById(ctx, domain.PeerIdentifier(id))
|
||||
peer, err := e.peers.GetById(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeer(peer))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeer(peer))
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,24 +172,26 @@ func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/new [post]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e PeerEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var peer models.Peer
|
||||
err := c.BindJSON(&peer)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &peer); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(peer); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newPeer, err := e.peers.Create(ctx, models.NewDomainPeer(&peer))
|
||||
newPeer, err := e.peers.Create(r.Context(), models.NewDomainPeer(&peer))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeer(newPeer))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeer(newPeer))
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,30 +212,33 @@ func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/by-id/{id} [put]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
var peer models.Peer
|
||||
err := c.BindJSON(&peer)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &peer); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(peer); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
updatedPeer, err := e.peers.Update(ctx, domain.PeerIdentifier(id), models.NewDomainPeer(&peer))
|
||||
updatedPeer, err := e.peers.Update(r.Context(), domain.PeerIdentifier(id), models.NewDomainPeer(&peer))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeer(updatedPeer))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeer(updatedPeer))
|
||||
}
|
||||
}
|
||||
|
||||
@ -241,22 +257,22 @@ func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /peer/by-id/{id} [delete]
|
||||
// @Security BasicAuth
|
||||
func (e PeerEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e PeerEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.peers.Delete(ctx, domain.PeerIdentifier(id))
|
||||
err := e.peers.Delete(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
@ -5,8 +5,10 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
@ -23,12 +25,20 @@ type ProvisioningEndpointProvisioningService interface {
|
||||
}
|
||||
|
||||
type ProvisioningEndpoint struct {
|
||||
provisioning ProvisioningEndpointProvisioningService
|
||||
provisioning ProvisioningEndpointProvisioningService
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func NewProvisioningEndpoint(provisioning ProvisioningEndpointProvisioningService) *ProvisioningEndpoint {
|
||||
func NewProvisioningEndpoint(
|
||||
authenticator Authenticator,
|
||||
validator Validator,
|
||||
provisioning ProvisioningEndpointProvisioningService,
|
||||
) *ProvisioningEndpoint {
|
||||
return &ProvisioningEndpoint{
|
||||
provisioning: provisioning,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
provisioning: provisioning,
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,14 +46,15 @@ func (e ProvisioningEndpoint) GetName() string {
|
||||
return "ProvisioningEndpoint"
|
||||
}
|
||||
|
||||
func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/provisioning", authenticator.LoggedIn())
|
||||
func (e ProvisioningEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/provisioning")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/data/user-info", authenticator.LoggedIn(), e.handleUserInfoGet())
|
||||
apiGroup.GET("/data/peer-config", authenticator.LoggedIn(), e.handlePeerConfigGet())
|
||||
apiGroup.GET("/data/peer-qr", authenticator.LoggedIn(), e.handlePeerQrGet())
|
||||
apiGroup.HandleFunc("GET /data/user-info", e.handleUserInfoGet())
|
||||
apiGroup.HandleFunc("GET /data/peer-config", e.handlePeerConfigGet())
|
||||
apiGroup.HandleFunc("GET /data/peer-qr", e.handlePeerQrGet())
|
||||
|
||||
apiGroup.POST("/new-peer", authenticator.LoggedIn(), e.handleNewPeerPost())
|
||||
apiGroup.HandleFunc("POST /new-peer", e.handleNewPeerPost())
|
||||
}
|
||||
|
||||
// handleUserInfoGet returns a gorm Handler function.
|
||||
@ -63,24 +74,23 @@ func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /provisioning/data/user-info [get]
|
||||
// @Security BasicAuth
|
||||
func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := strings.TrimSpace(c.Query("UserId"))
|
||||
email := strings.TrimSpace(c.Query("Email"))
|
||||
func (e ProvisioningEndpoint) handleUserInfoGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimSpace(request.Query(r, "UserId"))
|
||||
email := strings.TrimSpace(request.Query(r, "Email"))
|
||||
|
||||
if id == "" && email == "" {
|
||||
id = string(domain.GetUserInfo(ctx).Id)
|
||||
id = string(domain.GetUserInfo(r.Context()).Id)
|
||||
}
|
||||
|
||||
user, peers, err := e.provisioning.GetUserAndPeers(ctx, domain.UserIdentifier(id), email)
|
||||
user, peers, err := e.provisioning.GetUserAndPeers(r.Context(), domain.UserIdentifier(id), email)
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUserInformation(user, peers))
|
||||
respond.JSON(w, http.StatusOK, models.NewUserInformation(user, peers))
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,23 +111,23 @@ func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /provisioning/data/peer-config [get]
|
||||
// @Security BasicAuth
|
||||
func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := strings.TrimSpace(c.Query("PeerId"))
|
||||
func (e ProvisioningEndpoint) handlePeerConfigGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimSpace(request.Query(r, "PeerId"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
peerConfig, err := e.provisioning.GetPeerConfig(ctx, domain.PeerIdentifier(id))
|
||||
peerConfig, err := e.provisioning.GetPeerConfig(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/plain", peerConfig)
|
||||
respond.Data(w, http.StatusOK, "text/plain", peerConfig)
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,23 +148,23 @@ func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /provisioning/data/peer-qr [get]
|
||||
// @Security BasicAuth
|
||||
func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := strings.TrimSpace(c.Query("PeerId"))
|
||||
func (e ProvisioningEndpoint) handlePeerQrGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimSpace(request.Query(r, "PeerId"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing peer id"})
|
||||
return
|
||||
}
|
||||
|
||||
peerConfigQrCode, err := e.provisioning.GetPeerQrPng(ctx, domain.PeerIdentifier(id))
|
||||
peerConfigQrCode, err := e.provisioning.GetPeerQrPng(r.Context(), domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "image/png", peerConfigQrCode)
|
||||
respond.Data(w, http.StatusOK, "image/png", peerConfigQrCode)
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,23 +184,25 @@ func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /provisioning/new-peer [post]
|
||||
// @Security BasicAuth
|
||||
func (e ProvisioningEndpoint) handleNewPeerPost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e ProvisioningEndpoint) handleNewPeerPost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req models.ProvisioningRequest
|
||||
err := c.BindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(req); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := e.provisioning.NewPeer(ctx, req)
|
||||
peer, err := e.provisioning.NewPeer(r.Context(), req)
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewPeer(peer))
|
||||
respond.JSON(w, http.StatusOK, models.NewPeer(peer))
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-pkgz/routegroup"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/request"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v1/models"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
@ -19,12 +21,20 @@ type UserService interface {
|
||||
}
|
||||
|
||||
type UserEndpoint struct {
|
||||
users UserService
|
||||
users UserService
|
||||
authenticator Authenticator
|
||||
validator Validator
|
||||
}
|
||||
|
||||
func NewUserEndpoint(userService UserService) *UserEndpoint {
|
||||
func NewUserEndpoint(
|
||||
authenticator Authenticator,
|
||||
validator Validator,
|
||||
userService UserService,
|
||||
) *UserEndpoint {
|
||||
return &UserEndpoint{
|
||||
users: userService,
|
||||
authenticator: authenticator,
|
||||
validator: validator,
|
||||
users: userService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,14 +42,15 @@ func (e UserEndpoint) GetName() string {
|
||||
return "UserEndpoint"
|
||||
}
|
||||
|
||||
func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/user", authenticator.LoggedIn())
|
||||
func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||
apiGroup := g.Mount("/user")
|
||||
apiGroup.Use(e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet())
|
||||
apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut())
|
||||
apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet())
|
||||
apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut())
|
||||
apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete())
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm Handler function.
|
||||
@ -53,17 +64,16 @@ func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /user/all [get]
|
||||
// @Security BasicAuth
|
||||
func (e UserEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
users, err := e.users.GetAll(ctx)
|
||||
func (e UserEndpoint) handleAllGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := e.users.GetAll(r.Context())
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUsers(users))
|
||||
respond.JSON(w, http.StatusOK, models.NewUsers(users))
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,23 +92,23 @@ func (e UserEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /user/by-id/{id} [get]
|
||||
// @Security BasicAuth
|
||||
func (e UserEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e UserEndpoint) handleByIdGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := e.users.GetById(ctx, domain.UserIdentifier(id))
|
||||
user, err := e.users.GetById(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUser(user, true))
|
||||
respond.JSON(w, http.StatusOK, models.NewUser(user, true))
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,24 +128,26 @@ func (e UserEndpoint) handleByIdGet() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /user/new [post]
|
||||
// @Security BasicAuth
|
||||
func (e UserEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
func (e UserEndpoint) handleCreatePost() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var user models.User
|
||||
err := c.BindJSON(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := e.users.Create(ctx, models.NewDomainUser(&user))
|
||||
newUser, err := e.users.Create(r.Context(), models.NewDomainUser(&user))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUser(newUser, true))
|
||||
respond.JSON(w, http.StatusOK, models.NewUser(newUser, true))
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,30 +168,33 @@ func (e UserEndpoint) handleCreatePost() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /user/by-id/{id} [put]
|
||||
// @Security BasicAuth
|
||||
func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e UserEndpoint) handleUpdatePut() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
var user models.User
|
||||
err := c.BindJSON(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
if err := request.BodyJson(r, &user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := e.validator.Struct(user); err != nil {
|
||||
respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
updateUser, err := e.users.Update(ctx, domain.UserIdentifier(id), models.NewDomainUser(&user))
|
||||
updateUser, err := e.users.Update(r.Context(), domain.UserIdentifier(id), models.NewDomainUser(&user))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.NewUser(updateUser, true))
|
||||
respond.JSON(w, http.StatusOK, models.NewUser(updateUser, true))
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,22 +213,22 @@ func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc {
|
||||
// @Failure 500 {object} models.Error
|
||||
// @Router /user/by-id/{id} [delete]
|
||||
// @Security BasicAuth
|
||||
func (e UserEndpoint) handleDelete() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := c.Param("id")
|
||||
func (e UserEndpoint) handleDelete() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := request.Path(r, "id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
respond.JSON(w, http.StatusBadRequest,
|
||||
models.Error{Code: http.StatusBadRequest, Message: "missing user id"})
|
||||
return
|
||||
}
|
||||
|
||||
err := e.users.Delete(ctx, domain.UserIdentifier(id))
|
||||
err := e.users.Delete(r.Context(), domain.UserIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(ParseServiceError(err))
|
||||
status, model := ParseServiceError(err)
|
||||
respond.JSON(w, status, model)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
respond.Status(w, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
@ -1,93 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
|
||||
)
|
||||
|
||||
type UserSource interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type authenticationHandler struct {
|
||||
userSource UserSource
|
||||
}
|
||||
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
if !ok || username == "" || password == "" {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
// check if user exists in DB
|
||||
|
||||
ctx := domain.SetUserInfo(c.Request.Context(), domain.SystemAdminContextUserInfo())
|
||||
user, err := h.userSource.GetUser(ctx, domain.UserIdentifier(username))
|
||||
if err != nil {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
// validate API token
|
||||
if err := user.CheckApiToken(password); err != nil {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
if !UserHasScopes(user, scopes...) {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{
|
||||
Id: user.Identifier,
|
||||
IsAdmin: user.IsAdmin,
|
||||
})
|
||||
|
||||
// Continue down the chain to Handler etc
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func UserHasScopes(user *domain.User, scopes ...Scope) bool {
|
||||
// No scopes give, so the check should succeed
|
||||
if len(scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// check if user has admin scope
|
||||
if user.IsAdmin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if admin scope is required
|
||||
for _, scope := range scopes {
|
||||
if scope == ScopeAdmin {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
101
internal/app/api/v1/handlers/web_authentication.go
Normal file
101
internal/app/api/v1/handlers/web_authentication.go
Normal file
@ -0,0 +1,101 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app/api/core/respond"
|
||||
"github.com/h44z/wg-portal/internal/app/api/v0/model"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes
|
||||
)
|
||||
|
||||
type UserAuthenticator interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type AuthenticationHandler struct {
|
||||
authenticator UserAuthenticator
|
||||
}
|
||||
|
||||
func NewAuthenticationHandler(authenticator UserAuthenticator) AuthenticationHandler {
|
||||
return AuthenticationHandler{
|
||||
authenticator: authenticator,
|
||||
}
|
||||
}
|
||||
|
||||
// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well.
|
||||
func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username == "" || password == "" {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
// check if user exists in DB
|
||||
|
||||
ctx := domain.SetUserInfo(r.Context(), domain.SystemAdminContextUserInfo())
|
||||
user, err := h.authenticator.GetUser(ctx, domain.UserIdentifier(username))
|
||||
if err != nil {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
// validate API token
|
||||
if err := user.CheckApiToken(password); err != nil {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusUnauthorized,
|
||||
model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
if !UserHasScopes(user, scopes...) {
|
||||
// Abort the request with the appropriate error code
|
||||
respond.JSON(w, http.StatusForbidden,
|
||||
model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{
|
||||
Id: user.Identifier,
|
||||
IsAdmin: user.IsAdmin,
|
||||
})
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Continue down the chain to Handler etc
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func UserHasScopes(user *domain.User, scopes ...Scope) bool {
|
||||
// No scopes give, so the check should succeed
|
||||
if len(scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// check if user has admin scope
|
||||
if user.IsAdmin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if admin scope is required
|
||||
for _, scope := range scopes {
|
||||
if scope == ScopeAdmin {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
@ -3,6 +3,8 @@ package config
|
||||
type WebConfig struct {
|
||||
// RequestLogging enables logging of all HTTP requests.
|
||||
RequestLogging bool `yaml:"request_logging"`
|
||||
// ExposeHostInfo sets whether the host information should be exposed in a response header.
|
||||
ExposeHostInfo bool `yaml:"expose_host_info"`
|
||||
// ExternalUrl is the URL where a client can access WireGuard Portal.
|
||||
// This is used for the callback URL of the OAuth providers.
|
||||
ExternalUrl string `yaml:"external_url"`
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const CtxUserInfo = "userInfo"
|
||||
@ -47,21 +45,6 @@ func SystemAdminContextUserInfo() *ContextUserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// SetUserInfoFromGin sets the user info from the gin context to the request context.
|
||||
func SetUserInfoFromGin(c *gin.Context) context.Context {
|
||||
ginUserInfo, exists := c.Get(CtxUserInfo)
|
||||
|
||||
info := DefaultContextUserInfo()
|
||||
if exists {
|
||||
if ginInfo, ok := ginUserInfo.(*ContextUserInfo); ok {
|
||||
info = ginInfo
|
||||
}
|
||||
}
|
||||
|
||||
ctx := SetUserInfo(c.Request.Context(), info)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// SetUserInfo sets the user info in the context.
|
||||
func SetUserInfo(ctx context.Context, info *ContextUserInfo) context.Context {
|
||||
ctx = context.WithValue(ctx, CtxUserInfo, info)
|
||||
|
Loading…
x
Reference in New Issue
Block a user