mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-13 14:31:15 +00:00
initial commit
This commit is contained in:
99
internal/common/configuration.go
Normal file
99
internal/common/configuration.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/wireguard"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/ldap"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
|
||||
|
||||
// LoadConfigFile parses yaml files. It uses to yaml annotation to store the data in a struct.
|
||||
func loadConfigFile(cfg interface{}, filename string) error {
|
||||
s := reflect.ValueOf(cfg)
|
||||
|
||||
if s.Kind() != reflect.Ptr {
|
||||
return ErrInvalidSpecification
|
||||
}
|
||||
s = s.Elem()
|
||||
if s.Kind() != reflect.Struct {
|
||||
return ErrInvalidSpecification
|
||||
}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
decoder := yaml.NewDecoder(f)
|
||||
err = decoder.Decode(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
|
||||
func loadConfigEnv(cfg interface{}) error {
|
||||
err := envconfig.Process("", cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Core struct {
|
||||
ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"`
|
||||
Title string `yaml:"title" envconfig:"WEBSITE_TITLE"`
|
||||
} `yaml:"core"`
|
||||
|
||||
LDAP ldap.Config `yaml:"ldap"`
|
||||
WG wireguard.Config `yaml:"wg"`
|
||||
AdminLdapGroup string `yaml:"adminLdapGroup" envconfig:"ADMIN_LDAP_GROUP"`
|
||||
LogoutRedirectPath string `yaml:"logoutRedirectPath" envconfig:"LOGOUT_REDIRECT_PATH"`
|
||||
AuthRoutePrefix string `yaml:"authRoutePrefix" envconfig:"AUTH_ROUTE_PREFIX"`
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
cfg := &Config{}
|
||||
|
||||
// Default config
|
||||
cfg.Core.ListeningAddress = ":8080"
|
||||
cfg.Core.Title = "WireGuard VPN"
|
||||
cfg.LDAP.URL = "ldap://srv-ad01.company.local:389"
|
||||
cfg.LDAP.BaseDN = "DC=COMPANY,DC=LOCAL"
|
||||
cfg.LDAP.StartTLS = true
|
||||
cfg.LDAP.BindUser = "company\\ldap_wireguard"
|
||||
cfg.LDAP.BindPass = "SuperSecret"
|
||||
cfg.WG.DeviceName = "wg0"
|
||||
cfg.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
|
||||
cfg.LogoutRedirectPath = "/"
|
||||
cfg.AuthRoutePrefix = "/auth"
|
||||
|
||||
// Load config from file and environment
|
||||
cfgFile, ok := os.LookupEnv("CONFIG_FILE")
|
||||
if !ok {
|
||||
cfgFile = "config.yml" // Default config file
|
||||
}
|
||||
err := loadConfigFile(cfg, cfgFile)
|
||||
if err != nil {
|
||||
log.Warnf("unable to load config.yml file: %v, using default configuration...", err)
|
||||
}
|
||||
err = loadConfigEnv(cfg)
|
||||
if err != nil {
|
||||
log.Warnf("unable to load environment config: %v", err)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
37
internal/common/iputil.go
Normal file
37
internal/common/iputil.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package common
|
||||
|
||||
import "net"
|
||||
|
||||
// BroadcastAddr returns the last address in the given network, or the broadcast address.
|
||||
func BroadcastAddr(n *net.IPNet) net.IP {
|
||||
// The golang net package doesn't make it easy to calculate the broadcast address. :(
|
||||
var broadcast net.IP
|
||||
if len(n.IP) == 4 {
|
||||
broadcast = net.ParseIP("0.0.0.0").To4()
|
||||
} else {
|
||||
broadcast = net.ParseIP("::")
|
||||
}
|
||||
for i := 0; i < len(n.IP); i++ {
|
||||
broadcast[i] = n.IP[i] | ^n.Mask[i]
|
||||
}
|
||||
return broadcast
|
||||
}
|
||||
|
||||
// http://play.golang.org/p/m8TNTtygK0
|
||||
func IncreaseIP(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
if ip[j] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsIPv6 check if given ip is IPv6
|
||||
func IsIPv6(address string) bool {
|
||||
ip := net.ParseIP(address)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.To4() == nil
|
||||
}
|
88
internal/ldap/authentication.go
Normal file
88
internal/ldap/authentication.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Cfg *Config
|
||||
}
|
||||
|
||||
func NewAuthentication(config Config) Authentication {
|
||||
a := Authentication{
|
||||
Cfg: &config,
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (a Authentication) open() (*ldap.Conn, error) {
|
||||
conn, err := ldap.DialURL(a.Cfg.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.Cfg.StartTLS {
|
||||
// Reconnect with TLS
|
||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Bind(a.Cfg.BindUser, a.Cfg.BindPass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (a Authentication) close(conn *ldap.Conn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (a Authentication) CheckLogin(username, password string) bool {
|
||||
return a.CheckCustomLogin("sAMAccountName", username, password)
|
||||
}
|
||||
|
||||
func (a Authentication) CheckCustomLogin(userIdentifier, username, password string) bool {
|
||||
client, err := a.open()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer a.close(client)
|
||||
|
||||
// Search for the given username
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
a.Cfg.BaseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass=organizationalPerson)(%s=%s))", userIdentifier, username),
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := client.Search(searchRequest)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(sr.Entries) != 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
userDN := sr.Entries[0].DN
|
||||
|
||||
// Bind as the user to verify their password
|
||||
err = client.Bind(userDN, password)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
9
internal/ldap/ldap.go
Normal file
9
internal/ldap/ldap.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package ldap
|
||||
|
||||
type Config struct {
|
||||
URL string `yaml:"url" envconfig:"LDAP_URL"`
|
||||
StartTLS bool `yaml:"startTLS" envconfig:"LDAP_STARTTLS"`
|
||||
BaseDN string `yaml:"dn" envconfig:"LDAP_BASEDN"`
|
||||
BindUser string `yaml:"user" envconfig:"LDAP_USER"`
|
||||
BindPass string `yaml:"pass" envconfig:"LDAP_PASSWORD"`
|
||||
}
|
455
internal/ldap/usercache.go
Normal file
455
internal/ldap/usercache.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var Fields = []string{"givenName", "sn", "mail", "department", "memberOf", "sAMAccountName", "telephoneNumber",
|
||||
"mobile", "displayName", "cn", "title", "company", "manager", "streetAddress", "employeeID", "memberOf", "l",
|
||||
"st", "postalCode", "co", "facsimileTelephoneNumber", "pager", "thumbnailPhoto", "otherMobile",
|
||||
"extensionAttribute2", "distinguishedName", "userAccountControl"}
|
||||
|
||||
var ModifiableFields = []string{"department", "telephoneNumber", "mobile", "displayName", "title", "company",
|
||||
"manager", "streetAddress", "employeeID", "l", "st", "postalCode", "co", "thumbnailPhoto"}
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
// Cache Data Store
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
type UserCacheHolder interface {
|
||||
Clear()
|
||||
SetAllUsers(users []RawLdapData)
|
||||
SetUser(data RawLdapData)
|
||||
GetUser(dn string) *RawLdapData
|
||||
GetUsers() []*RawLdapData
|
||||
}
|
||||
|
||||
type RawLdapData struct {
|
||||
DN string
|
||||
Attributes map[string]string
|
||||
RawAttributes map[string][][]byte
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
// Sample Cache Data store
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
type UserCacheHolderEntry struct {
|
||||
RawLdapData
|
||||
Username string
|
||||
Mail string
|
||||
Firstname string
|
||||
Lastname string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
func (e *UserCacheHolderEntry) CalcFieldsFromAttributes() {
|
||||
e.Username = strings.ToLower(e.Attributes["sAMAccountName"])
|
||||
e.Mail = e.Attributes["mail"]
|
||||
e.Firstname = e.Attributes["givenName"]
|
||||
e.Lastname = e.Attributes["sn"]
|
||||
e.Groups = make([]string, len(e.RawAttributes["memberOf"]))
|
||||
for i, group := range e.RawAttributes["memberOf"] {
|
||||
e.Groups[i] = string(group)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *UserCacheHolderEntry) GetUID() string {
|
||||
return fmt.Sprintf("u%x", md5.Sum([]byte(e.Attributes["distinguishedName"])))
|
||||
}
|
||||
|
||||
type SynchronizedUserCacheHolder struct {
|
||||
users map[string]*UserCacheHolderEntry
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) Init() {
|
||||
h.users = make(map[string]*UserCacheHolderEntry)
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) Clear() {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.users = make(map[string]*UserCacheHolderEntry)
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) SetAllUsers(users []RawLdapData) {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.users = make(map[string]*UserCacheHolderEntry)
|
||||
|
||||
for i := range users {
|
||||
h.users[users[i].DN] = &UserCacheHolderEntry{RawLdapData: users[i]}
|
||||
h.users[users[i].DN].CalcFieldsFromAttributes()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) SetUser(user RawLdapData) {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.users[user.DN] = &UserCacheHolderEntry{RawLdapData: user}
|
||||
h.users[user.DN].CalcFieldsFromAttributes()
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetUser(dn string) *RawLdapData {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
return &h.users[dn].RawLdapData
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetUserData(dn string) *UserCacheHolderEntry {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
return h.users[dn]
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetUsers() []*RawLdapData {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
users := make([]*RawLdapData, 0, len(h.users))
|
||||
for _, user := range h.users {
|
||||
users = append(users, &user.RawLdapData)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetSortedUsers(sortKey string, sortDirection string) []*UserCacheHolderEntry {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
sortedUsers := make([]*UserCacheHolderEntry, 0, len(h.users))
|
||||
|
||||
for _, user := range h.users {
|
||||
sortedUsers = append(sortedUsers, user)
|
||||
}
|
||||
|
||||
sort.Slice(sortedUsers, func(i, j int) bool {
|
||||
if sortDirection == "asc" {
|
||||
return sortedUsers[i].Attributes[sortKey] < sortedUsers[j].Attributes[sortKey]
|
||||
} else {
|
||||
return sortedUsers[i].Attributes[sortKey] > sortedUsers[j].Attributes[sortKey]
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return sortedUsers
|
||||
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetFilteredUsers(sortKey string, sortDirection string, search, searchDepartment string) []*UserCacheHolderEntry {
|
||||
sortedUsers := h.GetSortedUsers(sortKey, sortDirection)
|
||||
if search == "" && searchDepartment == "" {
|
||||
return sortedUsers // skip filtering
|
||||
}
|
||||
|
||||
filteredUsers := make([]*UserCacheHolderEntry, 0, len(sortedUsers))
|
||||
for _, user := range sortedUsers {
|
||||
if searchDepartment != "" && user.Attributes["department"] != searchDepartment {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(user.Attributes["sn"], search) ||
|
||||
strings.Contains(user.Attributes["givenName"], search) ||
|
||||
strings.Contains(user.Mail, search) ||
|
||||
strings.Contains(user.Attributes["department"], search) ||
|
||||
strings.Contains(user.Attributes["telephoneNumber"], search) ||
|
||||
strings.Contains(user.Attributes["mobile"], search) {
|
||||
filteredUsers = append(filteredUsers, user)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredUsers
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) IsInGroup(username, gid string) bool {
|
||||
userDN := h.GetUserDN(username)
|
||||
if userDN == "" {
|
||||
return false // user not found -> not in group
|
||||
}
|
||||
|
||||
user := h.GetUserData(userDN)
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, group := range user.Groups {
|
||||
if group == gid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) UserExists(username string) bool {
|
||||
userDN := h.GetUserDN(username)
|
||||
if userDN == "" {
|
||||
return false // user not found
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetUserDN(username string) string {
|
||||
userDN := ""
|
||||
for dn, user := range h.users {
|
||||
accName := strings.ToLower(user.Attributes["sAMAccountName"])
|
||||
if accName == username {
|
||||
userDN = dn
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return userDN
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetUserDNByMail(mail string) string {
|
||||
userDN := ""
|
||||
for dn, user := range h.users {
|
||||
accMail := strings.ToLower(user.Attributes["mail"])
|
||||
if accMail == mail {
|
||||
userDN = dn
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return userDN
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetTeamLeaders() []*UserCacheHolderEntry {
|
||||
|
||||
sortedUsers := h.GetSortedUsers("sn", "asc")
|
||||
teamLeaders := make([]*UserCacheHolderEntry, 0, len(sortedUsers))
|
||||
for _, user := range sortedUsers {
|
||||
if user.Attributes["extensionAttribute2"] != "Teamleiter" {
|
||||
continue
|
||||
}
|
||||
|
||||
teamLeaders = append(teamLeaders, user)
|
||||
}
|
||||
|
||||
return teamLeaders
|
||||
}
|
||||
|
||||
func (h *SynchronizedUserCacheHolder) GetDepartments() []string {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
departmentSet := make(map[string]struct{})
|
||||
for _, user := range h.users {
|
||||
if user.Attributes["department"] == "" {
|
||||
continue
|
||||
}
|
||||
departmentSet[user.Attributes["department"]] = struct{}{}
|
||||
}
|
||||
|
||||
departments := make([]string, len(departmentSet))
|
||||
i := 0
|
||||
for department := range departmentSet {
|
||||
departments[i] = department
|
||||
i++
|
||||
}
|
||||
|
||||
sort.Strings(departments)
|
||||
|
||||
return departments
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
// Cache Handler, LDAP interaction
|
||||
// --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
type UserCache struct {
|
||||
Cfg *Config
|
||||
LastError error
|
||||
UpdatedAt time.Time
|
||||
userData UserCacheHolder
|
||||
}
|
||||
|
||||
func NewUserCache(config Config, store UserCacheHolder) *UserCache {
|
||||
uc := &UserCache{
|
||||
Cfg: &config,
|
||||
UpdatedAt: time.Now(),
|
||||
userData: store,
|
||||
}
|
||||
|
||||
log.Infof("Filling user cache...")
|
||||
err := uc.Update(true)
|
||||
log.Infof("User cache filled!")
|
||||
uc.LastError = err
|
||||
|
||||
return uc
|
||||
}
|
||||
|
||||
func (u UserCache) open() (*ldap.Conn, error) {
|
||||
conn, err := ldap.DialURL(u.Cfg.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.Cfg.StartTLS {
|
||||
// Reconnect with TLS
|
||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Bind(u.Cfg.BindUser, u.Cfg.BindPass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (u UserCache) close(conn *ldap.Conn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Update updates the user cache in background, minimal locking will happen
|
||||
func (u *UserCache) Update(filter bool) error {
|
||||
log.Debugf("Updating ldap cache...")
|
||||
client, err := u.open()
|
||||
if err != nil {
|
||||
u.LastError = err
|
||||
return err
|
||||
}
|
||||
defer u.close(client)
|
||||
|
||||
// Search for the given username
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
u.Cfg.BaseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=organizationalPerson)",
|
||||
Fields,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := client.Search(searchRequest)
|
||||
if err != nil {
|
||||
u.LastError = err
|
||||
return err
|
||||
}
|
||||
|
||||
tmpData := make([]RawLdapData, 0, len(sr.Entries))
|
||||
|
||||
for _, entry := range sr.Entries {
|
||||
if filter {
|
||||
usernameAttr := strings.ToLower(entry.GetAttributeValue("sAMAccountName"))
|
||||
firstNameAttr := entry.GetAttributeValue("givenName")
|
||||
lastNameAttr := entry.GetAttributeValue("sn")
|
||||
mailAttr := entry.GetAttributeValue("mail")
|
||||
userAccountControl := entry.GetAttributeValue("userAccountControl")
|
||||
employeeID := entry.GetAttributeValue("employeeID")
|
||||
dn := entry.GetAttributeValue("distinguishedName")
|
||||
|
||||
if usernameAttr == "" || firstNameAttr == "" || lastNameAttr == "" || mailAttr == "" || employeeID == "" {
|
||||
continue // prefilter...
|
||||
}
|
||||
|
||||
if userAccountControl == "" || userAccountControl == "514" {
|
||||
continue // 514 means account is disabled
|
||||
}
|
||||
|
||||
if entry.DN != dn {
|
||||
log.Errorf("LDAP inconsistent: '%s' != '%s'", entry.DN, dn)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tmp := RawLdapData{
|
||||
DN: entry.DN,
|
||||
Attributes: make(map[string]string, len(Fields)),
|
||||
RawAttributes: make(map[string][][]byte, len(Fields)),
|
||||
}
|
||||
|
||||
for _, field := range Fields {
|
||||
tmp.Attributes[field] = entry.GetAttributeValue(field)
|
||||
tmp.RawAttributes[field] = entry.GetRawAttributeValues(field)
|
||||
}
|
||||
|
||||
tmpData = append(tmpData, tmp)
|
||||
}
|
||||
|
||||
// Copy to userdata
|
||||
u.userData.SetAllUsers(tmpData)
|
||||
u.UpdatedAt = time.Now()
|
||||
u.LastError = nil
|
||||
|
||||
log.Debug("Ldap cache updated...")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserCache) ModifyUserData(dn string, newData RawLdapData, fields []string) error {
|
||||
if fields == nil {
|
||||
fields = ModifiableFields // default
|
||||
}
|
||||
|
||||
existingUserData := u.userData.GetUser(dn)
|
||||
if existingUserData == nil {
|
||||
return fmt.Errorf("user with dn %s not found", dn)
|
||||
}
|
||||
|
||||
modify := ldap.NewModifyRequest(dn, nil)
|
||||
|
||||
for _, ldapAttribute := range fields {
|
||||
if existingUserData.Attributes[ldapAttribute] == newData.Attributes[ldapAttribute] {
|
||||
continue // do not update unchanged fields
|
||||
}
|
||||
|
||||
if len(existingUserData.RawAttributes[ldapAttribute]) == 0 && newData.Attributes[ldapAttribute] != "" {
|
||||
modify.Add(ldapAttribute, []string{newData.Attributes[ldapAttribute]})
|
||||
newData.RawAttributes[ldapAttribute] = [][]byte{
|
||||
[]byte(newData.Attributes[ldapAttribute]),
|
||||
}
|
||||
}
|
||||
if len(existingUserData.RawAttributes[ldapAttribute]) != 0 && newData.Attributes[ldapAttribute] != "" {
|
||||
modify.Replace(ldapAttribute, []string{newData.Attributes[ldapAttribute]})
|
||||
newData.RawAttributes[ldapAttribute][0] = []byte(newData.Attributes[ldapAttribute])
|
||||
}
|
||||
if len(existingUserData.RawAttributes[ldapAttribute]) != 0 && newData.Attributes[ldapAttribute] == "" {
|
||||
modify.Delete(ldapAttribute, []string{})
|
||||
newData.RawAttributes[ldapAttribute] = [][]byte{} // clear list
|
||||
}
|
||||
}
|
||||
|
||||
if len(modify.Changes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := u.open()
|
||||
if err != nil {
|
||||
u.LastError = err
|
||||
return err
|
||||
}
|
||||
defer u.close(client)
|
||||
|
||||
err = client.Modify(modify)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Once written to ldap, update the local cache
|
||||
u.userData.SetUser(newData)
|
||||
|
||||
return nil
|
||||
}
|
242
internal/server/core.go
Normal file
242
internal/server/core.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/wireguard"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/common"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gin-gonic/contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const SessionIdentifier = "wgPortalSession"
|
||||
const CacheRefreshDuration = 5 * time.Minute
|
||||
|
||||
func init() {
|
||||
gob.Register(SessionData{})
|
||||
}
|
||||
|
||||
type SessionData struct {
|
||||
LoggedIn bool
|
||||
IsAdmin bool
|
||||
UID string
|
||||
UserName string
|
||||
Firstname string
|
||||
Lastname string
|
||||
SortedBy string
|
||||
SortDirection string
|
||||
Search string
|
||||
AlertData string
|
||||
AlertType string
|
||||
}
|
||||
|
||||
type AlertData struct {
|
||||
HasAlert bool
|
||||
Message string
|
||||
Type string
|
||||
}
|
||||
|
||||
type StaticData struct {
|
||||
WebsiteTitle string
|
||||
WebsiteLogo string
|
||||
LoginURL string
|
||||
LogoutURL string
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// Core components
|
||||
config *common.Config
|
||||
server *gin.Engine
|
||||
users *UserManager
|
||||
|
||||
// WireGuard stuff
|
||||
wg *wireguard.Manager
|
||||
|
||||
// LDAP stuff
|
||||
ldapAuth ldap.Authentication
|
||||
ldapUsers *ldap.SynchronizedUserCacheHolder
|
||||
ldapCacheUpdater *ldap.UserCache
|
||||
}
|
||||
|
||||
func (s *Server) Setup() error {
|
||||
// Init rand
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
s.config = common.NewConfig()
|
||||
|
||||
// Setup LDAP stuff
|
||||
s.ldapAuth = ldap.NewAuthentication(s.config.LDAP)
|
||||
s.ldapUsers = &ldap.SynchronizedUserCacheHolder{}
|
||||
s.ldapUsers.Init()
|
||||
s.ldapCacheUpdater = ldap.NewUserCache(s.config.LDAP, s.ldapUsers)
|
||||
if s.ldapCacheUpdater.LastError != nil {
|
||||
return s.ldapCacheUpdater.LastError
|
||||
}
|
||||
|
||||
// Setup WireGuard stuff
|
||||
s.wg = &wireguard.Manager{Cfg: &s.config.WG}
|
||||
if err := s.wg.Init(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setup user manager
|
||||
s.users = NewUserManager()
|
||||
if s.users == nil {
|
||||
return errors.New("unable to setup user manager")
|
||||
}
|
||||
s.users.InitWithDevice(s.wg.GetDeviceInfo())
|
||||
s.users.InitWithPeers(s.wg.GetPeerList())
|
||||
|
||||
dir := s.getExecutableDirectory()
|
||||
rDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
|
||||
log.Infof("Real working directory: %s", rDir)
|
||||
log.Infof("Current working directory: %s", dir)
|
||||
|
||||
// Setup http server
|
||||
s.server = gin.Default()
|
||||
|
||||
// Setup templates
|
||||
log.Infof("Loading templates from: %s", filepath.Join(dir, "/assets/tpl/*.html"))
|
||||
s.server.LoadHTMLGlob(filepath.Join(dir, "/assets/tpl/*.html"))
|
||||
s.server.Use(sessions.Sessions("authsession", sessions.NewCookieStore([]byte("secret"))))
|
||||
|
||||
// Serve static files
|
||||
s.server.Static("/css", filepath.Join(dir, "/assets/css"))
|
||||
s.server.Static("/js", filepath.Join(dir, "/assets/js"))
|
||||
s.server.Static("/img", filepath.Join(dir, "/assets/img"))
|
||||
s.server.Static("/fonts", filepath.Join(dir, "/assets/fonts"))
|
||||
|
||||
// Setup all routes
|
||||
SetupRoutes(s)
|
||||
|
||||
log.Infof("Setup of service completed!")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Run() {
|
||||
// Start ldap group watcher
|
||||
go func(s *Server) {
|
||||
for {
|
||||
time.Sleep(CacheRefreshDuration)
|
||||
if err := s.ldapCacheUpdater.Update(true); err != nil {
|
||||
log.Warnf("Failed to update ldap group cache: %v", err)
|
||||
}
|
||||
log.Debugf("Refreshed LDAP permissions!")
|
||||
}
|
||||
}(s)
|
||||
|
||||
// Run web service
|
||||
err := s.server.Run(s.config.Core.ListeningAddress)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to listen and serve on %s: %v", s.config.Core.ListeningAddress, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getExecutableDirectory() string {
|
||||
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get executable directory: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) {
|
||||
return "." // assets directory not found -> we are developing in goland =)
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func (s *Server) getSessionData(c *gin.Context) SessionData {
|
||||
session := sessions.Default(c)
|
||||
rawSessionData := session.Get(SessionIdentifier)
|
||||
|
||||
var sessionData SessionData
|
||||
if rawSessionData != nil {
|
||||
sessionData = rawSessionData.(SessionData)
|
||||
} else {
|
||||
sessionData = SessionData{
|
||||
SortedBy: "sn",
|
||||
SortDirection: "asc",
|
||||
Firstname: "",
|
||||
Lastname: "",
|
||||
IsAdmin: false,
|
||||
LoggedIn: false,
|
||||
}
|
||||
session.Set(SessionIdentifier, sessionData)
|
||||
if err := session.Save(); err != nil {
|
||||
log.Errorf("Failed to store session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionData
|
||||
}
|
||||
|
||||
func (s *Server) getAlertData(c *gin.Context) AlertData {
|
||||
currentSession := s.getSessionData(c)
|
||||
alertData := AlertData{
|
||||
HasAlert: currentSession.AlertData != "",
|
||||
Message: currentSession.AlertData,
|
||||
Type: currentSession.AlertType,
|
||||
}
|
||||
// Reset alerts
|
||||
_ = s.setAlert(c, "", "")
|
||||
|
||||
return alertData
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionData(c *gin.Context, data SessionData) error {
|
||||
session := sessions.Default(c)
|
||||
session.Set(SessionIdentifier, data)
|
||||
if err := session.Save(); err != nil {
|
||||
log.Errorf("Failed to store session: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) destroySessionData(c *gin.Context) error {
|
||||
session := sessions.Default(c)
|
||||
session.Delete(SessionIdentifier)
|
||||
if err := session.Save(); err != nil {
|
||||
log.Errorf("Failed to destroy session: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) getStaticData() StaticData {
|
||||
return StaticData{
|
||||
WebsiteTitle: s.config.Core.Title,
|
||||
LoginURL: s.config.AuthRoutePrefix + "/login",
|
||||
LogoutURL: s.config.AuthRoutePrefix + "/logout",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setAlert(c *gin.Context, message, typ string) SessionData {
|
||||
currentSession := s.getSessionData(c)
|
||||
currentSession.AlertData = message
|
||||
currentSession.AlertType = typ
|
||||
_ = s.updateSessionData(c, currentSession)
|
||||
|
||||
return currentSession
|
||||
}
|
||||
|
||||
func (s SessionData) GetSortIcon(field string) string {
|
||||
if s.SortedBy != field {
|
||||
return "fa-sort"
|
||||
}
|
||||
if s.SortDirection == "asc" {
|
||||
return "fa-sort-alpha-down"
|
||||
} else {
|
||||
return "fa-sort-alpha-up"
|
||||
}
|
||||
}
|
57
internal/server/handlers.go
Normal file
57
internal/server/handlers.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) GetIndex(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "index.html", gin.H{
|
||||
"route": c.Request.URL.Path,
|
||||
"session": s.getSessionData(c),
|
||||
"static": s.getStaticData(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) HandleError(c *gin.Context, code int, message, details string) {
|
||||
// TODO: if json
|
||||
//c.JSON(code, gin.H{"error": message, "details": details})
|
||||
|
||||
c.HTML(code, "error.html", gin.H{
|
||||
"data": gin.H{
|
||||
"Code": strconv.Itoa(code),
|
||||
"Message": message,
|
||||
"Details": details,
|
||||
},
|
||||
"route": c.Request.URL.Path,
|
||||
"session": s.getSessionData(c),
|
||||
"static": s.getStaticData(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) GetAdminIndex(c *gin.Context) {
|
||||
dev, err := s.wg.GetDeviceInfo()
|
||||
if err != nil {
|
||||
s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
|
||||
return
|
||||
}
|
||||
peers, err := s.wg.GetPeerList()
|
||||
if err != nil {
|
||||
s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
users := make([]User, len(peers))
|
||||
for i, peer := range peers {
|
||||
users[i] = s.users.GetOrCreateUserForPeer(peer)
|
||||
}
|
||||
c.HTML(http.StatusOK, "admin_index.html", gin.H{
|
||||
"route": c.Request.URL.Path,
|
||||
"session": s.getSessionData(c),
|
||||
"static": s.getStaticData(),
|
||||
"peers": users,
|
||||
"interface": dev,
|
||||
})
|
||||
}
|
96
internal/server/handlers_auth.go
Normal file
96
internal/server/handlers_auth.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) GetLogin(c *gin.Context) {
|
||||
currentSession := s.getSessionData(c)
|
||||
if currentSession.LoggedIn {
|
||||
c.Redirect(http.StatusSeeOther, "/") // already logged in
|
||||
}
|
||||
|
||||
authError := c.DefaultQuery("err", "")
|
||||
errMsg := "Unknown error occurred, try again!"
|
||||
switch authError {
|
||||
case "missingdata":
|
||||
errMsg = "Invalid login data retrieved, please fill out all fields and try again!"
|
||||
case "authfail":
|
||||
errMsg = "Authentication failed!"
|
||||
case "loginreq":
|
||||
errMsg = "Login required!"
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "login.html", gin.H{
|
||||
"error": authError != "",
|
||||
"message": errMsg,
|
||||
"static": s.getStaticData(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) PostLogin(c *gin.Context) {
|
||||
currentSession := s.getSessionData(c)
|
||||
if currentSession.LoggedIn {
|
||||
// already logged in
|
||||
c.Redirect(http.StatusSeeOther, "/")
|
||||
return
|
||||
}
|
||||
|
||||
username := strings.ToLower(c.PostForm("username"))
|
||||
password := c.PostForm("password")
|
||||
|
||||
// Validate form input
|
||||
if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" {
|
||||
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=missingdata")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is in cache, avoid unnecessary ldap requests
|
||||
if !s.ldapUsers.UserExists(username) {
|
||||
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=authfail")
|
||||
}
|
||||
|
||||
// Check if username and password match
|
||||
if !s.ldapAuth.CheckLogin(username, password) {
|
||||
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=authfail")
|
||||
return
|
||||
}
|
||||
|
||||
dn := s.ldapUsers.GetUserDN(username)
|
||||
userData := s.ldapUsers.GetUserData(dn)
|
||||
sessionData := SessionData{
|
||||
LoggedIn: true,
|
||||
IsAdmin: s.ldapUsers.IsInGroup(username, s.config.AdminLdapGroup),
|
||||
UID: userData.GetUID(),
|
||||
UserName: username,
|
||||
Firstname: userData.Firstname,
|
||||
Lastname: userData.Lastname,
|
||||
SortedBy: "sn",
|
||||
SortDirection: "asc",
|
||||
Search: "",
|
||||
}
|
||||
|
||||
if err := s.updateSessionData(c, sessionData); err != nil {
|
||||
s.HandleError(c, http.StatusInternalServerError, "login error", "failed to save session")
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusSeeOther, "/")
|
||||
}
|
||||
|
||||
func (s *Server) GetLogout(c *gin.Context) {
|
||||
currentSession := s.getSessionData(c)
|
||||
|
||||
if !currentSession.LoggedIn { // Not logged in
|
||||
c.Redirect(http.StatusSeeOther, s.config.LogoutRedirectPath)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.destroySessionData(c); err != nil {
|
||||
s.HandleError(c, http.StatusInternalServerError, "logout error", "failed to destroy session")
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusSeeOther, s.config.LogoutRedirectPath)
|
||||
}
|
51
internal/server/routes.go
Normal file
51
internal/server/routes.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetupRoutes(s *Server) {
|
||||
// Startpage
|
||||
s.server.GET("/", s.GetIndex)
|
||||
|
||||
// Auth routes
|
||||
auth := s.server.Group("/auth")
|
||||
auth.GET("/login", s.GetLogin)
|
||||
auth.POST("/login", s.PostLogin)
|
||||
auth.GET("/logout", s.GetLogout)
|
||||
|
||||
// Admin routes
|
||||
admin := s.server.Group("/admin")
|
||||
admin.Use(s.RequireAuthentication(s.config.AdminLdapGroup))
|
||||
admin.GET("/", s.GetAdminIndex)
|
||||
|
||||
// User routes
|
||||
user := s.server.Group("/user")
|
||||
user.Use(s.RequireAuthentication("")) // empty scope = all logged in users
|
||||
}
|
||||
|
||||
func (s *Server) RequireAuthentication(scope string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := s.getSessionData(c)
|
||||
|
||||
if !session.LoggedIn {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.Redirect(http.StatusSeeOther, s.config.AuthRoutePrefix+"/login?err=loginreq")
|
||||
return
|
||||
}
|
||||
|
||||
if scope != "" && !s.ldapUsers.IsInGroup(session.UserName, s.config.AdminLdapGroup) && // admins always have access
|
||||
!s.ldapUsers.IsInGroup(session.UserName, scope) {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
s.HandleError(c, http.StatusUnauthorized, "unauthorized", "not enough permissions")
|
||||
return
|
||||
}
|
||||
|
||||
// Continue down the chain to handler etc
|
||||
c.Next()
|
||||
}
|
||||
}
|
322
internal/server/usermanager.go
Normal file
322
internal/server/usermanager.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/common"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Peer wgtypes.Peer `gorm:"-"`
|
||||
User *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap
|
||||
|
||||
UID string // uid for html identification
|
||||
IsOnline bool `gorm:"-"`
|
||||
Identifier string // Identifier AND Email make a WireGuard peer unique
|
||||
Email string `gorm:"index"`
|
||||
|
||||
IgnorePersistentKeepalive bool
|
||||
PresharedKey string
|
||||
AllowedIPsStr string
|
||||
IPsStr string
|
||||
AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
|
||||
IPs []string `gorm:"-"` // The IPs of the client
|
||||
PrivateKey string
|
||||
PublicKey string `gorm:"primaryKey"`
|
||||
|
||||
DeactivatedAt *time.Time
|
||||
CreatedBy string
|
||||
UpdatedBy string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (u *User) GetPeerConfig() wgtypes.PeerConfig {
|
||||
publicKey, _ := wgtypes.ParseKey(u.PublicKey)
|
||||
var presharedKey *wgtypes.Key
|
||||
if u.PresharedKey != "" {
|
||||
presharedKeyTmp, _ := wgtypes.ParseKey(u.PresharedKey)
|
||||
presharedKey = &presharedKeyTmp
|
||||
}
|
||||
|
||||
cfg := wgtypes.PeerConfig{
|
||||
PublicKey: publicKey,
|
||||
Remove: false,
|
||||
UpdateOnly: false,
|
||||
PresharedKey: presharedKey,
|
||||
Endpoint: nil,
|
||||
PersistentKeepaliveInterval: nil,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: make([]net.IPNet, len(u.IPs)),
|
||||
}
|
||||
for i, ip := range u.IPs {
|
||||
_, ipNet, err := net.ParseCIDR(ip)
|
||||
if err == nil {
|
||||
cfg.AllowedIPs[i] = *ipNet
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
DeviceName string `gorm:"primaryKey"`
|
||||
PrivateKey string
|
||||
PublicKey string
|
||||
PersistentKeepalive int
|
||||
ListenPort int
|
||||
Mtu int
|
||||
Endpoint string
|
||||
AllowedIPsStr string
|
||||
IPsStr string
|
||||
AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
|
||||
IPs []string `gorm:"-"` // The IPs of the client
|
||||
DNSStr string
|
||||
DNS []string `gorm:"-"` // The DNS servers of the client
|
||||
PreUp string
|
||||
PostUp string
|
||||
PreDown string
|
||||
PostDown string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (d *Device) IsValid() bool {
|
||||
if len(d.IPs) == 0 {
|
||||
return false
|
||||
}
|
||||
if d.Endpoint == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type UserManager struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserManager() *UserManager {
|
||||
um := &UserManager{}
|
||||
var err error
|
||||
um.db, err = gorm.Open(sqlite.Open("wg_portal.db"), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Errorf("failed to open sqlite database: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = um.db.AutoMigrate(&User{}, &Device{})
|
||||
if err != nil {
|
||||
log.Errorf("failed to migrate sqlite database: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return um
|
||||
}
|
||||
|
||||
func (u *UserManager) InitWithPeers(peers []wgtypes.Peer, err error) {
|
||||
if err != nil {
|
||||
log.Errorf("failed to init user-manager from peers: %v", err)
|
||||
return
|
||||
}
|
||||
for _, peer := range peers {
|
||||
u.GetOrCreateUserForPeer(peer)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UserManager) InitWithDevice(dev *wgtypes.Device, err error) {
|
||||
if err != nil {
|
||||
log.Errorf("failed to init user-manager from device: %v", err)
|
||||
return
|
||||
}
|
||||
u.GetOrCreateDevice(*dev)
|
||||
}
|
||||
|
||||
func (u *UserManager) GetAllUsers() []User {
|
||||
users := make([]User, 0)
|
||||
u.db.Find(&users)
|
||||
|
||||
for i := range users {
|
||||
users[i].AllowedIPs = strings.Split(users[i].AllowedIPsStr, ", ")
|
||||
users[i].IPs = strings.Split(users[i].IPsStr, ", ")
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
func (u *UserManager) GetAllDevices() []Device {
|
||||
devices := make([]Device, 0)
|
||||
u.db.Find(&devices)
|
||||
|
||||
for i := range devices {
|
||||
devices[i].AllowedIPs = strings.Split(devices[i].AllowedIPsStr, ", ")
|
||||
devices[i].IPs = strings.Split(devices[i].IPsStr, ", ")
|
||||
devices[i].DNS = strings.Split(devices[i].DNSStr, ", ")
|
||||
}
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
func (u *UserManager) GetOrCreateUserForPeer(peer wgtypes.Peer) User {
|
||||
user := User{}
|
||||
u.db.Where("public_key = ?", peer.PublicKey.String()).FirstOrInit(&user)
|
||||
|
||||
if user.PublicKey == "" { // user not found, create
|
||||
user.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey.String())))
|
||||
user.PublicKey = peer.PublicKey.String()
|
||||
user.PrivateKey = "" // UNKNOWN
|
||||
if peer.PresharedKey != (wgtypes.Key{}) {
|
||||
user.PresharedKey = peer.PresharedKey.String()
|
||||
}
|
||||
user.Email = "autodetected@example.com"
|
||||
user.Identifier = "Autodetected (" + user.PublicKey[0:8] + ")"
|
||||
user.UpdatedAt = time.Now()
|
||||
user.CreatedAt = time.Now()
|
||||
user.AllowedIPs = make([]string, 0) // UNKNOWN
|
||||
user.IPs = make([]string, len(peer.AllowedIPs))
|
||||
for i, ip := range peer.AllowedIPs {
|
||||
user.IPs[i] = ip.String()
|
||||
}
|
||||
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
|
||||
user.IPsStr = strings.Join(user.IPs, ", ")
|
||||
|
||||
res := u.db.Create(&user)
|
||||
if res.Error != nil {
|
||||
log.Errorf("failed to create autodetected peer: %v", res.Error)
|
||||
}
|
||||
}
|
||||
|
||||
user.IPs = strings.Split(user.IPsStr, ", ")
|
||||
user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ")
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func (u *UserManager) CreateUser(user User) error {
|
||||
user.UID = fmt.Sprintf("u%x", md5.Sum([]byte(user.PublicKey)))
|
||||
user.UpdatedAt = time.Now()
|
||||
user.CreatedAt = time.Now()
|
||||
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
|
||||
user.IPsStr = strings.Join(user.IPs, ", ")
|
||||
|
||||
res := u.db.Create(&user)
|
||||
if res.Error != nil {
|
||||
log.Errorf("failed to create user: %v", res.Error)
|
||||
return res.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserManager) UpdateUser(user User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
user.AllowedIPsStr = strings.Join(user.AllowedIPs, ", ")
|
||||
user.IPsStr = strings.Join(user.IPs, ", ")
|
||||
|
||||
res := u.db.Save(&user)
|
||||
if res.Error != nil {
|
||||
log.Errorf("failed to update user: %v", res.Error)
|
||||
return res.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserManager) GetAllReservedIps() ([]string, error) {
|
||||
reservedIps := make([]string, 0)
|
||||
users := u.GetAllUsers()
|
||||
for _, user := range users {
|
||||
for _, cidr := range user.IPs {
|
||||
ip, _, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"err": err,
|
||||
"cidr": cidr,
|
||||
}).Error("failed to ip from cidr")
|
||||
} else {
|
||||
reservedIps = append(reservedIps, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
devices := u.GetAllDevices()
|
||||
for _, device := range devices {
|
||||
for _, cidr := range device.IPs {
|
||||
ip, _, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"err": err,
|
||||
"cidr": cidr,
|
||||
}).Error("failed to ip from cidr")
|
||||
} else {
|
||||
reservedIps = append(reservedIps, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return reservedIps, nil
|
||||
}
|
||||
|
||||
// GetAvailableIp search for an available ip in cidr against a list of reserved ips
|
||||
func (u *UserManager) GetAvailableIp(cidr string, reserved []string) (string, error) {
|
||||
ip, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// this two addresses are not usable
|
||||
broadcastAddr := common.BroadcastAddr(ipnet).String()
|
||||
networkAddr := ipnet.IP.String()
|
||||
|
||||
for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); common.IncreaseIP(ip) {
|
||||
ok := true
|
||||
address := ip.String()
|
||||
for _, r := range reserved {
|
||||
if address == r {
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if ok && address != networkAddr && address != broadcastAddr {
|
||||
return address, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no more available address from cidr")
|
||||
}
|
||||
|
||||
func (u *UserManager) GetOrCreateDevice(dev wgtypes.Device) Device {
|
||||
device := Device{}
|
||||
u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
|
||||
|
||||
if device.PublicKey == "" { // device not found, create
|
||||
device.PublicKey = dev.PublicKey.String()
|
||||
device.PrivateKey = dev.PrivateKey.String()
|
||||
device.DeviceName = dev.Name
|
||||
device.ListenPort = dev.ListenPort
|
||||
device.Mtu = 0
|
||||
device.PersistentKeepalive = 16 // Default
|
||||
|
||||
res := u.db.Create(&device)
|
||||
if res.Error != nil {
|
||||
log.Errorf("failed to create autodetected device: %v", res.Error)
|
||||
}
|
||||
}
|
||||
|
||||
device.IPs = strings.Split(device.IPsStr, ", ")
|
||||
device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
|
||||
device.DNS = strings.Split(device.DNSStr, ", ")
|
||||
|
||||
return device
|
||||
}
|
5
internal/wireguard/config.go
Normal file
5
internal/wireguard/config.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package wireguard
|
||||
|
||||
type Config struct {
|
||||
DeviceName string `yaml:"device" envconfig:"WG_DEVICE"`
|
||||
}
|
103
internal/wireguard/manager.go
Normal file
103
internal/wireguard/manager.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
Cfg *Config
|
||||
wg *wgctrl.Client
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *Manager) Init() error {
|
||||
var err error
|
||||
m.wg, err = wgctrl.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create WireGuard client: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
|
||||
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
|
||||
}
|
||||
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
dev, err := m.wg.Device(m.Cfg.DeviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get WireGuard device: %w", err)
|
||||
}
|
||||
|
||||
return dev.Peers, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
publicKey, err := wgtypes.ParseKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid public key: %w", err)
|
||||
}
|
||||
|
||||
peers, err := m.GetPeerList()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get WireGuard peers: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
if peer.PublicKey == publicKey {
|
||||
return &peer, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not find WireGuard peer: %s", pubKey)
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not configure WireGuard device: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(pubKey string) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
publicKey, err := wgtypes.ParseKey(pubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid public key: %w", err)
|
||||
}
|
||||
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: publicKey,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not configure WireGuard device: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user