initial commit

This commit is contained in:
Christoph Haas
2020-11-05 19:37:51 +01:00
commit 93f7335b6e
70 changed files with 22081 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
package common
import (
"errors"
"os"
"reflect"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/h44z/wg-portal/internal/ldap"
"github.com/kelseyhightower/envconfig"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
// LoadConfigFile parses yaml files. It uses to yaml annotation to store the data in a struct.
func loadConfigFile(cfg interface{}, filename string) error {
s := reflect.ValueOf(cfg)
if s.Kind() != reflect.Ptr {
return ErrInvalidSpecification
}
s = s.Elem()
if s.Kind() != reflect.Struct {
return ErrInvalidSpecification
}
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
decoder := yaml.NewDecoder(f)
err = decoder.Decode(cfg)
if err != nil {
return err
}
return nil
}
// LoadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
func loadConfigEnv(cfg interface{}) error {
err := envconfig.Process("", cfg)
if err != nil {
return err
}
return nil
}
type Config struct {
Core struct {
ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"`
Title string `yaml:"title" envconfig:"WEBSITE_TITLE"`
} `yaml:"core"`
LDAP ldap.Config `yaml:"ldap"`
WG wireguard.Config `yaml:"wg"`
AdminLdapGroup string `yaml:"adminLdapGroup" envconfig:"ADMIN_LDAP_GROUP"`
LogoutRedirectPath string `yaml:"logoutRedirectPath" envconfig:"LOGOUT_REDIRECT_PATH"`
AuthRoutePrefix string `yaml:"authRoutePrefix" envconfig:"AUTH_ROUTE_PREFIX"`
}
func NewConfig() *Config {
cfg := &Config{}
// Default config
cfg.Core.ListeningAddress = ":8080"
cfg.Core.Title = "WireGuard VPN"
cfg.LDAP.URL = "ldap://srv-ad01.company.local:389"
cfg.LDAP.BaseDN = "DC=COMPANY,DC=LOCAL"
cfg.LDAP.StartTLS = true
cfg.LDAP.BindUser = "company\\ldap_wireguard"
cfg.LDAP.BindPass = "SuperSecret"
cfg.WG.DeviceName = "wg0"
cfg.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
cfg.LogoutRedirectPath = "/"
cfg.AuthRoutePrefix = "/auth"
// Load config from file and environment
cfgFile, ok := os.LookupEnv("CONFIG_FILE")
if !ok {
cfgFile = "config.yml" // Default config file
}
err := loadConfigFile(cfg, cfgFile)
if err != nil {
log.Warnf("unable to load config.yml file: %v, using default configuration...", err)
}
err = loadConfigEnv(cfg)
if err != nil {
log.Warnf("unable to load environment config: %v", err)
}
return cfg
}

37
internal/common/iputil.go Normal file
View File

@@ -0,0 +1,37 @@
package common
import "net"
// BroadcastAddr returns the last address in the given network, or the broadcast address.
func BroadcastAddr(n *net.IPNet) net.IP {
// The golang net package doesn't make it easy to calculate the broadcast address. :(
var broadcast net.IP
if len(n.IP) == 4 {
broadcast = net.ParseIP("0.0.0.0").To4()
} else {
broadcast = net.ParseIP("::")
}
for i := 0; i < len(n.IP); i++ {
broadcast[i] = n.IP[i] | ^n.Mask[i]
}
return broadcast
}
// http://play.golang.org/p/m8TNTtygK0
func IncreaseIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// IsIPv6 check if given ip is IPv6
func IsIPv6(address string) bool {
ip := net.ParseIP(address)
if ip == nil {
return false
}
return ip.To4() == nil
}

View File

@@ -0,0 +1,88 @@
package ldap
import (
"crypto/tls"
"fmt"
"github.com/go-ldap/ldap/v3"
)
type Authentication struct {
Cfg *Config
}
func NewAuthentication(config Config) Authentication {
a := Authentication{
Cfg: &config,
}
return a
}
func (a Authentication) open() (*ldap.Conn, error) {
conn, err := ldap.DialURL(a.Cfg.URL)
if err != nil {
return nil, err
}
if a.Cfg.StartTLS {
// Reconnect with TLS
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, err
}
}
err = conn.Bind(a.Cfg.BindUser, a.Cfg.BindPass)
if err != nil {
return nil, err
}
return conn, nil
}
func (a Authentication) close(conn *ldap.Conn) {
if conn != nil {
conn.Close()
}
}
func (a Authentication) CheckLogin(username, password string) bool {
return a.CheckCustomLogin("sAMAccountName", username, password)
}
func (a Authentication) CheckCustomLogin(userIdentifier, username, password string) bool {
client, err := a.open()
if err != nil {
return false
}
defer a.close(client)
// Search for the given username
searchRequest := ldap.NewSearchRequest(
a.Cfg.BaseDN,
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectClass=organizationalPerson)(%s=%s))", userIdentifier, username),
[]string{"dn"},
nil,
)
sr, err := client.Search(searchRequest)
if err != nil {
return false
}
if len(sr.Entries) != 1 {
return false
}
userDN := sr.Entries[0].DN
// Bind as the user to verify their password
err = client.Bind(userDN, password)
if err != nil {
return false
}
return true
}

9
internal/ldap/ldap.go Normal file
View File

@@ -0,0 +1,9 @@
package ldap
type Config struct {
URL string `yaml:"url" envconfig:"LDAP_URL"`
StartTLS bool `yaml:"startTLS" envconfig:"LDAP_STARTTLS"`
BaseDN string `yaml:"dn" envconfig:"LDAP_BASEDN"`
BindUser string `yaml:"user" envconfig:"LDAP_USER"`
BindPass string `yaml:"pass" envconfig:"LDAP_PASSWORD"`
}

455
internal/ldap/usercache.go Normal file
View File

@@ -0,0 +1,455 @@
package ldap
import (
"crypto/md5"
"crypto/tls"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/go-ldap/ldap/v3"
log "github.com/sirupsen/logrus"
)
var Fields = []string{"givenName", "sn", "mail", "department", "memberOf", "sAMAccountName", "telephoneNumber",
"mobile", "displayName", "cn", "title", "company", "manager", "streetAddress", "employeeID", "memberOf", "l",
"st", "postalCode", "co", "facsimileTelephoneNumber", "pager", "thumbnailPhoto", "otherMobile",
"extensionAttribute2", "distinguishedName", "userAccountControl"}
var ModifiableFields = []string{"department", "telephoneNumber", "mobile", "displayName", "title", "company",
"manager", "streetAddress", "employeeID", "l", "st", "postalCode", "co", "thumbnailPhoto"}
// --------------------------------------------------------------------------------------------------------------------
// Cache Data Store
// --------------------------------------------------------------------------------------------------------------------
type UserCacheHolder interface {
Clear()
SetAllUsers(users []RawLdapData)
SetUser(data RawLdapData)
GetUser(dn string) *RawLdapData
GetUsers() []*RawLdapData
}
type RawLdapData struct {
DN string
Attributes map[string]string
RawAttributes map[string][][]byte
}
// --------------------------------------------------------------------------------------------------------------------
// Sample Cache Data store
// --------------------------------------------------------------------------------------------------------------------
type UserCacheHolderEntry struct {
RawLdapData
Username string
Mail string
Firstname string
Lastname string
Groups []string
}
func (e *UserCacheHolderEntry) CalcFieldsFromAttributes() {
e.Username = strings.ToLower(e.Attributes["sAMAccountName"])
e.Mail = e.Attributes["mail"]
e.Firstname = e.Attributes["givenName"]
e.Lastname = e.Attributes["sn"]
e.Groups = make([]string, len(e.RawAttributes["memberOf"]))
for i, group := range e.RawAttributes["memberOf"] {
e.Groups[i] = string(group)
}
}
func (e *UserCacheHolderEntry) GetUID() string {
return fmt.Sprintf("u%x", md5.Sum([]byte(e.Attributes["distinguishedName"])))
}
type SynchronizedUserCacheHolder struct {
users map[string]*UserCacheHolderEntry
mux sync.RWMutex
}
func (h *SynchronizedUserCacheHolder) Init() {
h.users = make(map[string]*UserCacheHolderEntry)
}
func (h *SynchronizedUserCacheHolder) Clear() {
h.mux.Lock()
defer h.mux.Unlock()
h.users = make(map[string]*UserCacheHolderEntry)
}
func (h *SynchronizedUserCacheHolder) SetAllUsers(users []RawLdapData) {
h.mux.Lock()
defer h.mux.Unlock()
h.users = make(map[string]*UserCacheHolderEntry)
for i := range users {
h.users[users[i].DN] = &UserCacheHolderEntry{RawLdapData: users[i]}
h.users[users[i].DN].CalcFieldsFromAttributes()
}
}
func (h *SynchronizedUserCacheHolder) SetUser(user RawLdapData) {
h.mux.Lock()
defer h.mux.Unlock()
h.users[user.DN] = &UserCacheHolderEntry{RawLdapData: user}
h.users[user.DN].CalcFieldsFromAttributes()
}
func (h *SynchronizedUserCacheHolder) GetUser(dn string) *RawLdapData {
h.mux.RLock()
defer h.mux.RUnlock()
return &h.users[dn].RawLdapData
}
func (h *SynchronizedUserCacheHolder) GetUserData(dn string) *UserCacheHolderEntry {
h.mux.RLock()
defer h.mux.RUnlock()
return h.users[dn]
}
func (h *SynchronizedUserCacheHolder) GetUsers() []*RawLdapData {
h.mux.RLock()
defer h.mux.RUnlock()
users := make([]*RawLdapData, 0, len(h.users))
for _, user := range h.users {
users = append(users, &user.RawLdapData)
}
return users
}
func (h *SynchronizedUserCacheHolder) GetSortedUsers(sortKey string, sortDirection string) []*UserCacheHolderEntry {
h.mux.RLock()
defer h.mux.RUnlock()
sortedUsers := make([]*UserCacheHolderEntry, 0, len(h.users))
for _, user := range h.users {
sortedUsers = append(sortedUsers, user)
}
sort.Slice(sortedUsers, func(i, j int) bool {
if sortDirection == "asc" {
return sortedUsers[i].Attributes[sortKey] < sortedUsers[j].Attributes[sortKey]
} else {
return sortedUsers[i].Attributes[sortKey] > sortedUsers[j].Attributes[sortKey]
}
})
return sortedUsers
}
func (h *SynchronizedUserCacheHolder) GetFilteredUsers(sortKey string, sortDirection string, search, searchDepartment string) []*UserCacheHolderEntry {
sortedUsers := h.GetSortedUsers(sortKey, sortDirection)
if search == "" && searchDepartment == "" {
return sortedUsers // skip filtering
}
filteredUsers := make([]*UserCacheHolderEntry, 0, len(sortedUsers))
for _, user := range sortedUsers {
if searchDepartment != "" && user.Attributes["department"] != searchDepartment {
continue
}
if strings.Contains(user.Attributes["sn"], search) ||
strings.Contains(user.Attributes["givenName"], search) ||
strings.Contains(user.Mail, search) ||
strings.Contains(user.Attributes["department"], search) ||
strings.Contains(user.Attributes["telephoneNumber"], search) ||
strings.Contains(user.Attributes["mobile"], search) {
filteredUsers = append(filteredUsers, user)
}
}
return filteredUsers
}
func (h *SynchronizedUserCacheHolder) IsInGroup(username, gid string) bool {
userDN := h.GetUserDN(username)
if userDN == "" {
return false // user not found -> not in group
}
user := h.GetUserData(userDN)
if user == nil {
return false
}
for _, group := range user.Groups {
if group == gid {
return true
}
}
return false
}
func (h *SynchronizedUserCacheHolder) UserExists(username string) bool {
userDN := h.GetUserDN(username)
if userDN == "" {
return false // user not found
}
return true
}
func (h *SynchronizedUserCacheHolder) GetUserDN(username string) string {
userDN := ""
for dn, user := range h.users {
accName := strings.ToLower(user.Attributes["sAMAccountName"])
if accName == username {
userDN = dn
break
}
}
return userDN
}
func (h *SynchronizedUserCacheHolder) GetUserDNByMail(mail string) string {
userDN := ""
for dn, user := range h.users {
accMail := strings.ToLower(user.Attributes["mail"])
if accMail == mail {
userDN = dn
break
}
}
return userDN
}
func (h *SynchronizedUserCacheHolder) GetTeamLeaders() []*UserCacheHolderEntry {
sortedUsers := h.GetSortedUsers("sn", "asc")
teamLeaders := make([]*UserCacheHolderEntry, 0, len(sortedUsers))
for _, user := range sortedUsers {
if user.Attributes["extensionAttribute2"] != "Teamleiter" {
continue
}
teamLeaders = append(teamLeaders, user)
}
return teamLeaders
}
func (h *SynchronizedUserCacheHolder) GetDepartments() []string {
h.mux.RLock()
defer h.mux.RUnlock()
departmentSet := make(map[string]struct{})
for _, user := range h.users {
if user.Attributes["department"] == "" {
continue
}
departmentSet[user.Attributes["department"]] = struct{}{}
}
departments := make([]string, len(departmentSet))
i := 0
for department := range departmentSet {
departments[i] = department
i++
}
sort.Strings(departments)
return departments
}
// --------------------------------------------------------------------------------------------------------------------
// Cache Handler, LDAP interaction
// --------------------------------------------------------------------------------------------------------------------
type UserCache struct {
Cfg *Config
LastError error
UpdatedAt time.Time
userData UserCacheHolder
}
func NewUserCache(config Config, store UserCacheHolder) *UserCache {
uc := &UserCache{
Cfg: &config,
UpdatedAt: time.Now(),
userData: store,
}
log.Infof("Filling user cache...")
err := uc.Update(true)
log.Infof("User cache filled!")
uc.LastError = err
return uc
}
func (u UserCache) open() (*ldap.Conn, error) {
conn, err := ldap.DialURL(u.Cfg.URL)
if err != nil {
return nil, err
}
if u.Cfg.StartTLS {
// Reconnect with TLS
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, err
}
}
err = conn.Bind(u.Cfg.BindUser, u.Cfg.BindPass)
if err != nil {
return nil, err
}
return conn, nil
}
func (u UserCache) close(conn *ldap.Conn) {
if conn != nil {
conn.Close()
}
}
// Update updates the user cache in background, minimal locking will happen
func (u *UserCache) Update(filter bool) error {
log.Debugf("Updating ldap cache...")
client, err := u.open()
if err != nil {
u.LastError = err
return err
}
defer u.close(client)
// Search for the given username
searchRequest := ldap.NewSearchRequest(
u.Cfg.BaseDN,
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
"(objectClass=organizationalPerson)",
Fields,
nil,
)
sr, err := client.Search(searchRequest)
if err != nil {
u.LastError = err
return err
}
tmpData := make([]RawLdapData, 0, len(sr.Entries))
for _, entry := range sr.Entries {
if filter {
usernameAttr := strings.ToLower(entry.GetAttributeValue("sAMAccountName"))
firstNameAttr := entry.GetAttributeValue("givenName")
lastNameAttr := entry.GetAttributeValue("sn")
mailAttr := entry.GetAttributeValue("mail")
userAccountControl := entry.GetAttributeValue("userAccountControl")
employeeID := entry.GetAttributeValue("employeeID")
dn := entry.GetAttributeValue("distinguishedName")
if usernameAttr == "" || firstNameAttr == "" || lastNameAttr == "" || mailAttr == "" || employeeID == "" {
continue // prefilter...
}
if userAccountControl == "" || userAccountControl == "514" {
continue // 514 means account is disabled
}
if entry.DN != dn {
log.Errorf("LDAP inconsistent: '%s' != '%s'", entry.DN, dn)
continue
}
}
tmp := RawLdapData{
DN: entry.DN,
Attributes: make(map[string]string, len(Fields)),
RawAttributes: make(map[string][][]byte, len(Fields)),
}
for _, field := range Fields {
tmp.Attributes[field] = entry.GetAttributeValue(field)
tmp.RawAttributes[field] = entry.GetRawAttributeValues(field)
}
tmpData = append(tmpData, tmp)
}
// Copy to userdata
u.userData.SetAllUsers(tmpData)
u.UpdatedAt = time.Now()
u.LastError = nil
log.Debug("Ldap cache updated...")
return nil
}
func (u *UserCache) ModifyUserData(dn string, newData RawLdapData, fields []string) error {
if fields == nil {
fields = ModifiableFields // default
}
existingUserData := u.userData.GetUser(dn)
if existingUserData == nil {
return fmt.Errorf("user with dn %s not found", dn)
}
modify := ldap.NewModifyRequest(dn, nil)
for _, ldapAttribute := range fields {
if existingUserData.Attributes[ldapAttribute] == newData.Attributes[ldapAttribute] {
continue // do not update unchanged fields
}
if len(existingUserData.RawAttributes[ldapAttribute]) == 0 && newData.Attributes[ldapAttribute] != "" {
modify.Add(ldapAttribute, []string{newData.Attributes[ldapAttribute]})
newData.RawAttributes[ldapAttribute] = [][]byte{
[]byte(newData.Attributes[ldapAttribute]),
}
}
if len(existingUserData.RawAttributes[ldapAttribute]) != 0 && newData.Attributes[ldapAttribute] != "" {
modify.Replace(ldapAttribute, []string{newData.Attributes[ldapAttribute]})
newData.RawAttributes[ldapAttribute][0] = []byte(newData.Attributes[ldapAttribute])
}
if len(existingUserData.RawAttributes[ldapAttribute]) != 0 && newData.Attributes[ldapAttribute] == "" {
modify.Delete(ldapAttribute, []string{})
newData.RawAttributes[ldapAttribute] = [][]byte{} // clear list
}
}
if len(modify.Changes) == 0 {
return nil
}
client, err := u.open()
if err != nil {
u.LastError = err
return err
}
defer u.close(client)
err = client.Modify(modify)
if err != nil {
return err
}
// Once written to ldap, update the local cache
u.userData.SetUser(newData)
return nil
}

242
internal/server/core.go Normal file
View File

@@ -0,0 +1,242 @@
package server
import (
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"time"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/ldap"
log "github.com/sirupsen/logrus"
"github.com/gin-gonic/contrib/sessions"
"github.com/gin-gonic/gin"
)
const SessionIdentifier = "wgPortalSession"
const CacheRefreshDuration = 5 * time.Minute
func init() {
gob.Register(SessionData{})
}
type SessionData struct {
LoggedIn bool
IsAdmin bool
UID string
UserName string
Firstname string
Lastname string
SortedBy string
SortDirection string
Search string
AlertData string
AlertType string
}
type AlertData struct {
HasAlert bool
Message string
Type string
}
type StaticData struct {
WebsiteTitle string
WebsiteLogo string
LoginURL string
LogoutURL string
}
type Server struct {
// Core components
config *common.Config
server *gin.Engine
users *UserManager
// WireGuard stuff
wg *wireguard.Manager
// LDAP stuff
ldapAuth ldap.Authentication
ldapUsers *ldap.SynchronizedUserCacheHolder
ldapCacheUpdater *ldap.UserCache
}
func (s *Server) Setup() error {
// Init rand
rand.Seed(time.Now().UnixNano())
s.config = common.NewConfig()
// Setup LDAP stuff
s.ldapAuth = ldap.NewAuthentication(s.config.LDAP)
s.ldapUsers = &ldap.SynchronizedUserCacheHolder{}
s.ldapUsers.Init()
s.ldapCacheUpdater = ldap.NewUserCache(s.config.LDAP, s.ldapUsers)
if s.ldapCacheUpdater.LastError != nil {
return s.ldapCacheUpdater.LastError
}
// Setup WireGuard stuff
s.wg = &wireguard.Manager{Cfg: &s.config.WG}
if err := s.wg.Init(); err != nil {
return err
}
// Setup user manager
s.users = NewUserManager()
if s.users == nil {
return errors.New("unable to setup user manager")
}
s.users.InitWithDevice(s.wg.GetDeviceInfo())
s.users.InitWithPeers(s.wg.GetPeerList())
dir := s.getExecutableDirectory()
rDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
log.Infof("Real working directory: %s", rDir)
log.Infof("Current working directory: %s", dir)
// Setup http server
s.server = gin.Default()
// Setup templates
log.Infof("Loading templates from: %s", filepath.Join(dir, "/assets/tpl/*.html"))
s.server.LoadHTMLGlob(filepath.Join(dir, "/assets/tpl/*.html"))
s.server.Use(sessions.Sessions("authsession", sessions.NewCookieStore([]byte("secret"))))
// Serve static files
s.server.Static("/css", filepath.Join(dir, "/assets/css"))
s.server.Static("/js", filepath.Join(dir, "/assets/js"))
s.server.Static("/img", filepath.Join(dir, "/assets/img"))
s.server.Static("/fonts", filepath.Join(dir, "/assets/fonts"))
// Setup all routes
SetupRoutes(s)
log.Infof("Setup of service completed!")
return nil
}
func (s *Server) Run() {
// Start ldap group watcher
go func(s *Server) {
for {
time.Sleep(CacheRefreshDuration)
if err := s.ldapCacheUpdater.Update(true); err != nil {
log.Warnf("Failed to update ldap group cache: %v", err)
}
log.Debugf("Refreshed LDAP permissions!")
}
}(s)
// Run web service
err := s.server.Run(s.config.Core.ListeningAddress)
if err != nil {
log.Errorf("Failed to listen and serve on %s: %v", s.config.Core.ListeningAddress, err)
}
}
func (s *Server) getExecutableDirectory() string {
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
log.Errorf("Failed to get executable directory: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) {
return "." // assets directory not found -> we are developing in goland =)
}
return dir
}
func (s *Server) getSessionData(c *gin.Context) SessionData {
session := sessions.Default(c)
rawSessionData := session.Get(SessionIdentifier)
var sessionData SessionData
if rawSessionData != nil {
sessionData = rawSessionData.(SessionData)
} else {
sessionData = SessionData{
SortedBy: "sn",
SortDirection: "asc",
Firstname: "",
Lastname: "",
IsAdmin: false,
LoggedIn: false,
}
session.Set(SessionIdentifier, sessionData)
if err := session.Save(); err != nil {
log.Errorf("Failed to store session: %v", err)
}
}
return sessionData
}
func (s *Server) getAlertData(c *gin.Context) AlertData {
currentSession := s.getSessionData(c)
alertData := AlertData{
HasAlert: currentSession.AlertData != "",
Message: currentSession.AlertData,
Type: currentSession.AlertType,
}
// Reset alerts
_ = s.setAlert(c, "", "")
return alertData
}
func (s *Server) updateSessionData(c *gin.Context, data SessionData) error {
session := sessions.Default(c)
session.Set(SessionIdentifier, data)
if err := session.Save(); err != nil {
log.Errorf("Failed to store session: %v", err)
return err
}
return nil
}
func (s *Server) destroySessionData(c *gin.Context) error {
session := sessions.Default(c)
session.Delete(SessionIdentifier)
if err := session.Save(); err != nil {
log.Errorf("Failed to destroy session: %v", err)
return err
}
return nil
}
func (s *Server) getStaticData() StaticData {
return StaticData{
WebsiteTitle: s.config.Core.Title,
LoginURL: s.config.AuthRoutePrefix + "/login",
LogoutURL: s.config.AuthRoutePrefix + "/logout",
}
}
func (s *Server) setAlert(c *gin.Context, message, typ string) SessionData {
currentSession := s.getSessionData(c)
currentSession.AlertData = message
currentSession.AlertType = typ
_ = s.updateSessionData(c, currentSession)
return currentSession
}
func (s SessionData) GetSortIcon(field string) string {
if s.SortedBy != field {
return "fa-sort"
}
if s.SortDirection == "asc" {
return "fa-sort-alpha-down"
} else {
return "fa-sort-alpha-up"
}
}

View File

@@ -0,0 +1,57 @@
package server
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
func (s *Server) GetIndex(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", gin.H{
"route": c.Request.URL.Path,
"session": s.getSessionData(c),
"static": s.getStaticData(),
})
}
func (s *Server) HandleError(c *gin.Context, code int, message, details string) {
// TODO: if json
//c.JSON(code, gin.H{"error": message, "details": details})
c.HTML(code, "error.html", gin.H{
"data": gin.H{
"Code": strconv.Itoa(code),
"Message": message,
"Details": details,
},
"route": c.Request.URL.Path,
"session": s.getSessionData(c),
"static": s.getStaticData(),
})
}
func (s *Server) GetAdminIndex(c *gin.Context) {
dev, err := s.wg.GetDeviceInfo()
if err != nil {
s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
return
}
peers, err := s.wg.GetPeerList()
if err != nil {
s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
return
}
users := make([]User, len(peers))
for i, peer := range peers {
users[i] = s.users.GetOrCreateUserForPeer(peer)
}
c.HTML(http.StatusOK, "admin_index.html", gin.H{
"route": c.Request.URL.Path,
"session": s.getSessionData(c),
"static": s.getStaticData(),
"peers": users,
"interface": dev,
})
}

View File

@@ -0,0 +1,96 @@
package server
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func (s *Server) GetLogin(c *gin.Context) {
currentSession := s.getSessionData(c)
if currentSession.LoggedIn {
c.Redirect(http.StatusSeeOther, "/") // already logged in
}
authError := c.DefaultQuery("err", "")
errMsg := "Unknown error occurred, try again!"
switch authError {
case "missingdata":
errMsg = "Invalid login data retrieved, please fill out all fields and try again!"
case "authfail":
errMsg = "Authentication failed!"
case "loginreq":
errMsg = "Login required!"
}
c.HTML(http.StatusOK, "login.html", gin.H{
"error": authError != "",
"message": errMsg,
"static": s.getStaticData(),
})
}
func (s *Server) PostLogin(c *gin.Context) {
currentSession := s.getSessionData(c)
if currentSession.LoggedIn {
// already logged in
c.Redirect(http.StatusSeeOther, "/")
return
}
username := strings.ToLower(c.PostForm("username"))
password := c.PostForm("password")
// Validate form input
if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" {
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=missingdata")
return
}
// Check if user is in cache, avoid unnecessary ldap requests
if !s.ldapUsers.UserExists(username) {
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=authfail")
}
// Check if username and password match
if !s.ldapAuth.CheckLogin(username, password) {
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=authfail")
return
}
dn := s.ldapUsers.GetUserDN(username)
userData := s.ldapUsers.GetUserData(dn)
sessionData := SessionData{
LoggedIn: true,
IsAdmin: s.ldapUsers.IsInGroup(username, s.config.AdminLdapGroup),
UID: userData.GetUID(),
UserName: username,
Firstname: userData.Firstname,
Lastname: userData.Lastname,
SortedBy: "sn",
SortDirection: "asc",
Search: "",
}
if err := s.updateSessionData(c, sessionData); err != nil {
s.HandleError(c, http.StatusInternalServerError, "login error", "failed to save session")
return
}
c.Redirect(http.StatusSeeOther, "/")
}
func (s *Server) GetLogout(c *gin.Context) {
currentSession := s.getSessionData(c)
if !currentSession.LoggedIn { // Not logged in
c.Redirect(http.StatusSeeOther, s.config.LogoutRedirectPath)
return
}
if err := s.destroySessionData(c); err != nil {
s.HandleError(c, http.StatusInternalServerError, "logout error", "failed to destroy session")
return
}
c.Redirect(http.StatusSeeOther, s.config.LogoutRedirectPath)
}

51
internal/server/routes.go Normal file
View File

@@ -0,0 +1,51 @@
package server
import (
"net/http"
"github.com/gin-gonic/gin"
)
func SetupRoutes(s *Server) {
// Startpage
s.server.GET("/", s.GetIndex)
// Auth routes
auth := s.server.Group("/auth")
auth.GET("/login", s.GetLogin)
auth.POST("/login", s.PostLogin)
auth.GET("/logout", s.GetLogout)
// Admin routes
admin := s.server.Group("/admin")
admin.Use(s.RequireAuthentication(s.config.AdminLdapGroup))
admin.GET("/", s.GetAdminIndex)
// User routes
user := s.server.Group("/user")
user.Use(s.RequireAuthentication("")) // empty scope = all logged in users
}
func (s *Server) RequireAuthentication(scope string) gin.HandlerFunc {
return func(c *gin.Context) {
session := s.getSessionData(c)
if !session.LoggedIn {
// Abort the request with the appropriate error code
c.Abort()
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=loginreq")
return
}
if scope != "" && !s.ldapUsers.IsInGroup(session.UserName, s.config.AdminLdapGroup) && // admins always have access
!s.ldapUsers.IsInGroup(session.UserName, scope) {
// Abort the request with the appropriate error code
c.Abort()
s.HandleError(c, http.StatusUnauthorized, "unauthorized", "not enough permissions")
return
}
// Continue down the chain to handler etc
c.Next()
}
}

View File

@@ -0,0 +1,322 @@
package server
import (
"crypto/md5"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/ldap"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type User struct {
Peer wgtypes.Peer `gorm:"-"`
User *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap
UID string // uid for html identification
IsOnline bool `gorm:"-"`
Identifier string // Identifier AND Email make a WireGuard peer unique
Email string `gorm:"index"`
IgnorePersistentKeepalive bool
PresharedKey string
AllowedIPsStr string
IPsStr string
AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
IPs []string `gorm:"-"` // The IPs of the client
PrivateKey string
PublicKey string `gorm:"primaryKey"`
DeactivatedAt *time.Time
CreatedBy string
UpdatedBy string
CreatedAt time.Time
UpdatedAt time.Time
}
func (u *User) GetPeerConfig() wgtypes.PeerConfig {
publicKey, _ := wgtypes.ParseKey(u.PublicKey)
var presharedKey *wgtypes.Key
if u.PresharedKey != "" {
presharedKeyTmp, _ := wgtypes.ParseKey(u.PresharedKey)
presharedKey = &presharedKeyTmp
}
cfg := wgtypes.PeerConfig{
PublicKey: publicKey,
Remove: false,
UpdateOnly: false,
PresharedKey: presharedKey,
Endpoint: nil,
PersistentKeepaliveInterval: nil,
ReplaceAllowedIPs: true,
AllowedIPs: make([]net.IPNet, len(u.IPs)),
}
for i, ip := range u.IPs {
_, ipNet, err := net.ParseCIDR(ip)
if err == nil {
cfg.AllowedIPs[i] = *ipNet
}
}
return cfg
}
type Device struct {
DeviceName string `gorm:"primaryKey"`
PrivateKey string
PublicKey string
PersistentKeepalive int
ListenPort int
Mtu int
Endpoint string
AllowedIPsStr string
IPsStr string
AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
IPs []string `gorm:"-"` // The IPs of the client
DNSStr string
DNS []string `gorm:"-"` // The DNS servers of the client
PreUp string
PostUp string
PreDown string
PostDown string
CreatedAt time.Time
UpdatedAt time.Time
}
func (d *Device) IsValid() bool {
if len(d.IPs) == 0 {
return false
}
if d.Endpoint == "" {
return false
}
return true
}
type UserManager struct {
db *gorm.DB
}
func NewUserManager() *UserManager {
um := &UserManager{}
var err error
um.db, err = gorm.Open(sqlite.Open("wg_portal.db"), &gorm.Config{})
if err != nil {
log.Errorf("failed to open sqlite database: %v", err)
return nil
}
err = um.db.AutoMigrate(&User{}, &Device{})
if err != nil {
log.Errorf("failed to migrate sqlite database: %v", err)
return nil
}
return um
}
func (u *UserManager) InitWithPeers(peers []wgtypes.Peer, err error) {
if err != nil {
log.Errorf("failed to init user-manager from peers: %v", err)
return
}
for _, peer := range peers {
u.GetOrCreateUserForPeer(peer)
}
}
func (u *UserManager) InitWithDevice(dev *wgtypes.Device, err error) {
if err != nil {
log.Errorf("failed to init user-manager from device: %v", err)
return
}
u.GetOrCreateDevice(*dev)
}
func (u *UserManager) GetAllUsers() []User {
users := make([]User, 0)
u.db.Find(&users)
for i := range users {
users[i].AllowedIPs = strings.Split(users[i].AllowedIPsStr, ", ")
users[i].IPs = strings.Split(users[i].IPsStr, ", ")
}
return users
}
func (u *UserManager) GetAllDevices() []Device {
devices := make([]Device, 0)
u.db.Find(&devices)
for i := range devices {
devices[i].AllowedIPs = strings.Split(devices[i].AllowedIPsStr, ", ")
devices[i].IPs = strings.Split(devices[i].IPsStr, ", ")
devices[i].DNS = strings.Split(devices[i].DNSStr, ", ")
}
return devices
}
func (u *UserManager) GetOrCreateUserForPeer(peer wgtypes.Peer) User {
user := User{}
u.db.Where("public_key = ?", peer.PublicKey.String()).FirstOrInit(&user)
if user.PublicKey == "" { // user not found, create
user.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey.String())))
user.PublicKey = peer.PublicKey.String()
user.PrivateKey = "" // UNKNOWN
if peer.PresharedKey != (wgtypes.Key{}) {
user.PresharedKey = peer.PresharedKey.String()
}
user.Email = "autodetected@example.com"
user.Identifier = "Autodetected (" + user.PublicKey[0:8] + ")"
user.UpdatedAt = time.Now()
user.CreatedAt = time.Now()
user.AllowedIPs = make([]string, 0) // UNKNOWN
user.IPs = make([]string, len(peer.AllowedIPs))
for i, ip := range peer.AllowedIPs {
user.IPs[i] = ip.String()
}
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
user.IPsStr = strings.Join(user.IPs, ", ")
res := u.db.Create(&user)
if res.Error != nil {
log.Errorf("failed to create autodetected peer: %v", res.Error)
}
}
user.IPs = strings.Split(user.IPsStr, ", ")
user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ")
return user
}
func (u *UserManager) CreateUser(user User) error {
user.UID = fmt.Sprintf("u%x", md5.Sum([]byte(user.PublicKey)))
user.UpdatedAt = time.Now()
user.CreatedAt = time.Now()
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
user.IPsStr = strings.Join(user.IPs, ", ")
res := u.db.Create(&user)
if res.Error != nil {
log.Errorf("failed to create user: %v", res.Error)
return res.Error
}
return nil
}
func (u *UserManager) UpdateUser(user User) error {
user.UpdatedAt = time.Now()
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
user.IPsStr = strings.Join(user.IPs, ", ")
res := u.db.Save(&user)
if res.Error != nil {
log.Errorf("failed to update user: %v", res.Error)
return res.Error
}
return nil
}
func (u *UserManager) GetAllReservedIps() ([]string, error) {
reservedIps := make([]string, 0)
users := u.GetAllUsers()
for _, user := range users {
for _, cidr := range user.IPs {
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
log.WithFields(log.Fields{
"err": err,
"cidr": cidr,
}).Error("failed to ip from cidr")
} else {
reservedIps = append(reservedIps, ip.String())
}
}
}
devices := u.GetAllDevices()
for _, device := range devices {
for _, cidr := range device.IPs {
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
log.WithFields(log.Fields{
"err": err,
"cidr": cidr,
}).Error("failed to ip from cidr")
} else {
reservedIps = append(reservedIps, ip.String())
}
}
}
return reservedIps, nil
}
// GetAvailableIp search for an available ip in cidr against a list of reserved ips
func (u *UserManager) GetAvailableIp(cidr string, reserved []string) (string, error) {
ip, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return "", err
}
// this two addresses are not usable
broadcastAddr := common.BroadcastAddr(ipnet).String()
networkAddr := ipnet.IP.String()
for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); common.IncreaseIP(ip) {
ok := true
address := ip.String()
for _, r := range reserved {
if address == r {
ok = false
break
}
}
if ok && address != networkAddr && address != broadcastAddr {
return address, nil
}
}
return "", errors.New("no more available address from cidr")
}
func (u *UserManager) GetOrCreateDevice(dev wgtypes.Device) Device {
device := Device{}
u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
if device.PublicKey == "" { // device not found, create
device.PublicKey = dev.PublicKey.String()
device.PrivateKey = dev.PrivateKey.String()
device.DeviceName = dev.Name
device.ListenPort = dev.ListenPort
device.Mtu = 0
device.PersistentKeepalive = 16 // Default
res := u.db.Create(&device)
if res.Error != nil {
log.Errorf("failed to create autodetected device: %v", res.Error)
}
}
device.IPs = strings.Split(device.IPsStr, ", ")
device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
device.DNS = strings.Split(device.DNSStr, ", ")
return device
}

View File

@@ -0,0 +1,5 @@
package wireguard
type Config struct {
DeviceName string `yaml:"device" envconfig:"WG_DEVICE"`
}

View File

@@ -0,0 +1,103 @@
package wireguard
import (
"fmt"
"sync"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Manager struct {
Cfg *Config
wg *wgctrl.Client
mux sync.RWMutex
}
func (m *Manager) Init() error {
var err error
m.wg, err = wgctrl.New()
if err != nil {
return fmt.Errorf("could not create WireGuard client: %w", err)
}
return nil
}
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
dev, err := m.wg.Device(m.Cfg.DeviceName)
if err != nil {
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
}
return dev, nil
}
func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
m.mux.RLock()
defer m.mux.RUnlock()
dev, err := m.wg.Device(m.Cfg.DeviceName)
if err != nil {
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
}
return dev.Peers, nil
}
func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
m.mux.RLock()
defer m.mux.RUnlock()
publicKey, err := wgtypes.ParseKey(pubKey)
if err != nil {
return nil, fmt.Errorf("invalid public key: %w", err)
}
peers, err := m.GetPeerList()
if err != nil {
return nil, fmt.Errorf("could not get WireGuard peers: %w", err)
}
for _, peer := range peers {
if peer.PublicKey == publicKey {
return &peer, nil
}
}
return nil, fmt.Errorf("could not find WireGuard peer: %s", pubKey)
}
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
m.mux.Lock()
defer m.mux.Unlock()
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return fmt.Errorf("could not configure WireGuard device: %w", err)
}
return nil
}
func (m *Manager) RemovePeer(pubKey string) error {
m.mux.Lock()
defer m.mux.Unlock()
publicKey, err := wgtypes.ParseKey(pubKey)
if err != nil {
return fmt.Errorf("invalid public key: %w", err)
}
peer := wgtypes.PeerConfig{
PublicKey: publicKey,
Remove: true,
}
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
if err != nil {
return fmt.Errorf("could not configure WireGuard device: %w", err)
}
return nil
}