Compare commits

..

9 Commits

Author SHA1 Message Date
Christoph Haas
5017fb5759 update readme, fix default env file 2021-03-22 23:05:20 +01:00
Christoph Haas
29cd73aa46 fix TLS for email sending 2021-03-22 22:53:59 +01:00
Christoph Haas
6ece6e5be9 make ldap cert check configurable, fix CodeQL warnings 2021-03-22 22:52:08 +01:00
Christoph Haas
588f8c7c70 add csrf 2021-03-22 22:51:37 +01:00
Christoph Haas
68507c3bcd fix redirect after sending the peer email 2021-03-22 13:45:35 +01:00
Christoph Haas
1e9f845457 fix user_edit template 2021-03-22 13:42:28 +01:00
Christoph Haas
f95c692aed migrate peer database 2021-03-22 13:00:02 +01:00
Christoph Haas
f4edc55851 fix mail template (#3) and rename some variables, also change default ordering (latest handshake first) 2021-03-22 12:39:50 +01:00
Christoph Haas
6ab00ef567 WIP: support for multiple WireGuard devices (#2) 2021-03-21 12:36:11 +01:00
36 changed files with 646 additions and 565 deletions

View File

@@ -29,6 +29,7 @@ It also supports LDAP (Active Directory or OpenLDAP) as authentication provider.
* Responsive template * Responsive template
* One single binary * One single binary
* Can be used with existing WireGuard setups * Can be used with existing WireGuard setups
* Support for multiple WireGuard interfaces
![Screenshot](screenshot.png) ![Screenshot](screenshot.png)
@@ -54,14 +55,21 @@ services:
ports: ports:
- '8123:8123' - '8123:8123'
environment: environment:
# WireGuard Settings
- WG_DEVICES=wg0
- WG_DEFAULT_DEVICE=wg0
- WG_CONFIG_PATH=/etc/wireguard
# Core Settings
- EXTERNAL_URL=https://vpn.company.com - EXTERNAL_URL=https://vpn.company.com
- WEBSITE_TITLE=WireGuard VPN - WEBSITE_TITLE=WireGuard VPN
- COMPANY_NAME=Your Company Name - COMPANY_NAME=Your Company Name
- MAIL_FROM=WireGuard VPN <noreply+wireguard@company.com>
- ADMIN_USER=admin@domain.com - ADMIN_USER=admin@domain.com
- ADMIN_PASS=supersecret - ADMIN_PASS=supersecret
# Mail Settings
- MAIL_FROM=WireGuard VPN <noreply+wireguard@company.com>
- EMAIL_HOST=10.10.10.10 - EMAIL_HOST=10.10.10.10
- EMAIL_PORT=25 - EMAIL_PORT=25
# LDAP Settings
- LDAP_ENABLED=true - LDAP_ENABLED=true
- LDAP_URL=ldap://srv-ad01.company.local:389 - LDAP_URL=ldap://srv-ad01.company.local:389
- LDAP_BASEDN=DC=COMPANY,DC=LOCAL - LDAP_BASEDN=DC=COMPANY,DC=LOCAL
@@ -71,7 +79,7 @@ services:
``` ```
Please note that mapping ```/etc/wireguard``` to ```/etc/wireguard``` inside the docker, will erase your host's current configuration. Please note that mapping ```/etc/wireguard``` to ```/etc/wireguard``` inside the docker, will erase your host's current configuration.
If needed, please make sure to backup your files from ```/etc/wireguard```. If needed, please make sure to backup your files from ```/etc/wireguard```.
For a full list of configuration options take a look at the source file [internal/common/configuration.go](internal/common/configuration.go#L57). For a full list of configuration options take a look at the source file [internal/server/configuration.go](internal/server/configuration.go#L56).
### Standalone ### Standalone
For a standalone application, use the Makefile provided in the repository to build the application. For a standalone application, use the Makefile provided in the repository to build the application.
@@ -90,6 +98,7 @@ A detailed description for using this software with a raspberry pi can be found
* Generation or application of any `iptables` or `nftables` rules * Generation or application of any `iptables` or `nftables` rules
* Setting up or changing IP-addresses of the WireGuard interface on operating systems other than linux * Setting up or changing IP-addresses of the WireGuard interface on operating systems other than linux
* Importing private keys of an existing WireGuard setup
## Application stack ## Application stack

View File

@@ -25,6 +25,11 @@
} }
}); });
}); });
$(function() {
$('select.device-selector').change(function() {
this.form.submit();
});
});
})(jQuery); // End of use strict })(jQuery); // End of use strict

View File

@@ -20,6 +20,7 @@
<h2>Enter valid LDAP user email addresses to quickly create new accounts.</h2> <h2>Enter valid LDAP user email addresses to quickly create new accounts.</h2>
{{template "prt_flashes.html" .}} {{template "prt_flashes.html" .}}
<form method="post" enctype="multipart/form-data"> <form method="post" enctype="multipart/form-data">
<input type="hidden" name="_csrf" value="{{.Csrf}}">
<div class="form-row"> <div class="form-row">
<div class="form-group required col-md-12"> <div class="form-group required col-md-12">
<label for="inputEmail">Email Addresses</label> <label for="inputEmail">Email Addresses</label>

View File

@@ -22,6 +22,7 @@
{{template "prt_flashes.html" .}} {{template "prt_flashes.html" .}}
<form method="post" enctype="multipart/form-data"> <form method="post" enctype="multipart/form-data">
<input type="hidden" name="_csrf" value="{{.Csrf}}">
<input type="hidden" name="uid" value="{{.Peer.UID}}"> <input type="hidden" name="uid" value="{{.Peer.UID}}">
{{if .EditableKeys}} {{if .EditableKeys}}
<div class="form-row"> <div class="form-row">

View File

@@ -17,6 +17,7 @@
{{template "prt_flashes.html" .}} {{template "prt_flashes.html" .}}
<form method="post" enctype="multipart/form-data"> <form method="post" enctype="multipart/form-data">
<input type="hidden" name="_csrf" value="{{.Csrf}}">
<input type="hidden" name="device" value="{{.Device.DeviceName}}"> <input type="hidden" name="device" value="{{.Device.DeviceName}}">
<h3>Server's interface configuration</h3> <h3>Server's interface configuration</h3>
{{if .EditableKeys}} {{if .EditableKeys}}

View File

@@ -22,6 +22,7 @@
{{template "prt_flashes.html" .}} {{template "prt_flashes.html" .}}
<form method="post" enctype="multipart/form-data"> <form method="post" enctype="multipart/form-data">
<input type="hidden" name="_csrf" value="{{.Csrf}}">
{{if eq .User.CreatedAt .Epoch}} {{if eq .User.CreatedAt .Epoch}}
<div class="form-row"> <div class="form-row">
<div class="form-group required col-md-12"> <div class="form-group required col-md-12">

View File

@@ -106,6 +106,7 @@
</thead> </thead>
<tbody> <tbody>
{{range $i, $p :=.Peers}} {{range $i, $p :=.Peers}}
{{$peerUser:=(userForEmail $.Users $p.Email)}}
<tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}> <tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}>
<th scope="row" class="list-image-cell"> <th scope="row" class="list-image-cell">
<a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a> <a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a>
@@ -142,14 +143,14 @@
<div class="tab-content" id="tabContent{{$p.UID}}"> <div class="tab-content" id="tabContent{{$p.UID}}">
<div id="t1{{$p.UID}}" class="tab-pane fade active show"> <div id="t1{{$p.UID}}" class="tab-pane fade active show">
<h4>User details</h4> <h4>User details</h4>
{{if not $p.User}} {{if not $peerUser}}
<p>No user information available...</p> <p>No user information available...</p>
{{else}} {{else}}
<ul> <ul>
<li>Firstname: {{$p.User.Firstname}}</li> <li>Firstname: {{$peerUser.Firstname}}</li>
<li>Lastname: {{$p.User.Lastname}}</li> <li>Lastname: {{$peerUser.Lastname}}</li>
<li>Phone: {{$p.User.Phone}}</li> <li>Phone: {{$peerUser.Phone}}</li>
<li>Mail: {{$p.User.Email}}</li> <li>Mail: {{$peerUser.Email}}</li>
</ul> </ul>
{{end}} {{end}}
<h4>Connection / Traffic</h4> <h4>Connection / Traffic</h4>

View File

@@ -92,7 +92,7 @@
<th class="column-top" width="210" style="font-size:0pt; line-height:0pt; padding:0; margin:0; font-weight:normal; vertical-align:top;"> <th class="column-top" width="210" style="font-size:0pt; line-height:0pt; padding:0; margin:0; font-weight:normal; vertical-align:top;">
<table width="100%" border="0" cellspacing="0" cellpadding="0"> <table width="100%" border="0" cellspacing="0" cellpadding="0">
<tr> <tr>
<td class="fluid-img" style="font-size:0pt; line-height:0pt; text-align:left;"><img src="cid:{{.QrcodePngName}}" width="210" height="210" border="0" alt="" /></td> <td class="fluid-img" style="font-size:0pt; line-height:0pt; text-align:left;"><img src="cid:{{$.QrcodePngName}}" width="210" height="210" border="0" alt="" /></td>
</tr> </tr>
</table> </table>
</th> </th>
@@ -100,14 +100,14 @@
<th class="column-top" width="280" style="font-size:0pt; line-height:0pt; padding:0; margin:0; font-weight:normal; vertical-align:top;"> <th class="column-top" width="280" style="font-size:0pt; line-height:0pt; padding:0; margin:0; font-weight:normal; vertical-align:top;">
<table width="100%" border="0" cellspacing="0" cellpadding="0"> <table width="100%" border="0" cellspacing="0" cellpadding="0">
<tr> <tr>
{{if .Client.LdapUser}} {{if $.User}}
<td class="h4 pb20" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:20px; line-height:28px; text-align:left; padding-bottom:20px;">Hello {{.Client.LdapUser.Firstname}} {{.Client.LdapUser.Lastname}}</td> <td class="h4 pb20" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:20px; line-height:28px; text-align:left; padding-bottom:20px;">Hello {{$.User.Firstname}} {{$.User.Lastname}}</td>
{{else}} {{else}}
<td class="h4 pb20" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:20px; line-height:28px; text-align:left; padding-bottom:20px;">Hello</td> <td class="h4 pb20" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:20px; line-height:28px; text-align:left; padding-bottom:20px;">Hello</td>
{{end}} {{end}}
</tr> </tr>
<tr> <tr>
<td class="text pb20" style="color:#000000; font-family:Arial,sans-serif; font-size:14px; line-height:26px; text-align:left; padding-bottom:20px;">You or your administrator probably requested this VPN configuration. Scan the Qrcode or open the attached configuration file ({{.Client.GetConfigFileName}}) in the WireGuard VPN client to establish a secure VPN connection.</td> <td class="text pb20" style="color:#000000; font-family:Arial,sans-serif; font-size:14px; line-height:26px; text-align:left; padding-bottom:20px;">You or your administrator probably requested this VPN configuration. Scan the Qrcode or open the attached configuration file ({{$.Peer.GetConfigFileName}}) in the WireGuard VPN client to establish a secure VPN connection.</td>
</tr> </tr>
</table> </table>
</th> </th>
@@ -170,7 +170,7 @@
<td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated using WireGuard Portal.</td> <td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated using WireGuard Portal.</td>
</tr> </tr>
<tr> <tr>
<td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit WireGuard Portal</span></a></td> <td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{$.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit WireGuard Portal</span></a></td>
</tr> </tr>
</table> </table>
</td> </td>

View File

@@ -19,6 +19,7 @@
<div class="card-header">Please sign in</div> <div class="card-header">Please sign in</div>
<div class="card-body"> <div class="card-body">
<form class="form-signin" method="post"> <form class="form-signin" method="post">
<input type="hidden" name="_csrf" value="{{.Csrf}}">
<div class="form-group"> <div class="form-group">
<label for="inputUsername">Email</label> <label for="inputUsername">Email</label>
<input type="text" name="username" class="form-control" id="inputUsername" aria-describedby="usernameHelp" placeholder="Enter email"> <input type="text" name="username" class="form-control" id="inputUsername" aria-describedby="usernameHelp" placeholder="Enter email">

View File

@@ -22,6 +22,19 @@
{{end}} {{end}}
{{end}}{{end}} {{end}}{{end}}
</ul> </ul>
{{with eq $.Session.LoggedIn true}}{{with eq $.Session.IsAdmin true}}
{{with startsWith $.Route "/admin/"}}
<form class="form-inline my-2 my-lg-0" method="get">
<div class="form-group mr-sm-2">
<select name="device" id="inputDevice" class="form-control device-selector">
{{range $i, $d :=$.DeviceNames}}
<option value="{{$d}}" {{if eq $d $.Session.DeviceName}}selected{{end}}>{{$d}}</option>
{{end}}
</select>
</div>
</form>
{{end}}
{{end}}{{end}}
{{if eq $.Session.LoggedIn true}} {{if eq $.Session.LoggedIn true}}
<div class="nav-item dropdown"> <div class="nav-item dropdown">
<a href="#" class="navbar-text dropdown-toggle" data-toggle="dropdown">{{$.Session.Firstname}} {{$.Session.Lastname}} <span class="caret"></span></a> <a href="#" class="navbar-text dropdown-toggle" data-toggle="dropdown">{{$.Session.Firstname}} {{$.Session.Lastname}} <span class="caret"></span></a>
@@ -43,6 +56,6 @@
</nav> </nav>
{{if not $.Device.IsValid}} {{if not $.Device.IsValid}}
<div class="container"> <div class="container">
<div class="alert alert-danger">Warning: WireGuard Interface is not fully configured! Configurations may be incomplete and non functional!</div> <div class="alert alert-danger">Warning: WireGuard Interface {{$.Device.DeviceName}} is not fully configured! Configurations may be incomplete and non functional!</div>
</div> </div>
{{end}} {{end}}

View File

@@ -30,6 +30,7 @@
</thead> </thead>
<tbody> <tbody>
{{range $i, $p :=.Peers}} {{range $i, $p :=.Peers}}
{{$peerUser:=(userForEmail $.Users $p.Email)}}
<tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}> <tr id="user-pos-{{$i}}" {{if $p.DeactivatedAt}}class="disabled-peer"{{end}}>
<th scope="row" class="list-image-cell"> <th scope="row" class="list-image-cell">
<a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a> <a href="#{{$p.UID}}" data-toggle="collapse" class="collapse-indicator collapsed"></a>
@@ -58,14 +59,14 @@
<div class="tab-content" id="tabContent{{$p.UID}}"> <div class="tab-content" id="tabContent{{$p.UID}}">
<div id="t1{{$p.UID}}" class="tab-pane fade active show"> <div id="t1{{$p.UID}}" class="tab-pane fade active show">
<h4>User details</h4> <h4>User details</h4>
{{if not $p.User}} {{if not $peerUser}}
<p>No user information available...</p> <p>No user information available...</p>
{{else}} {{else}}
<ul> <ul>
<li>Firstname: {{$p.User.Firstname}}</li> <li>Firstname: {{$peerUser.Firstname}}</li>
<li>Lastname: {{$p.User.Lastname}}</li> <li>Lastname: {{$peerUser.Lastname}}</li>
<li>Phone: {{$p.User.Phone}}</li> <li>Phone: {{$peerUser.Phone}}</li>
<li>Mail: {{$p.User.Email}}</li> <li>Mail: {{$peerUser.Email}}</li>
</ul> </ul>
{{end}} {{end}}
<h4>Traffic</h4> <h4>Traffic</h4>

3
go.mod
View File

@@ -11,13 +11,12 @@ require (
github.com/jordan-wright/email v4.0.1-0.20200917010138-e1c00e156980+incompatible github.com/jordan-wright/email v4.0.1-0.20200917010138-e1c00e156980+incompatible
github.com/kelseyhightower/envconfig v1.4.0 github.com/kelseyhightower/envconfig v1.4.0
github.com/milosgajdos/tenus v0.0.3 github.com/milosgajdos/tenus v0.0.3
github.com/mitchellh/gox v1.0.1 // indirect
github.com/necrose99/gox v0.4.0 // indirect
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.7.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/tatsushid/go-fastping v0.0.0-20160109021039-d7bb493dee3e github.com/tatsushid/go-fastping v0.0.0-20160109021039-d7bb493dee3e
github.com/toorop/gin-logrus v0.0.0-20200831135515-d2ee50d38dae github.com/toorop/gin-logrus v0.0.0-20200831135515-d2ee50d38dae
github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200609130330-bd2cb7843e1b golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200609130330-bd2cb7843e1b
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c

View File

@@ -182,7 +182,7 @@ func (provider Provider) open() (*ldap.Conn, error) {
if provider.config.StartTLS { if provider.config.StartTLS {
// Reconnect with TLS // Reconnect with TLS
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) err = conn.StartTLS(&tls.Config{InsecureSkipVerify: !provider.config.CertValidation})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -7,6 +7,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/h44z/wg-portal/internal/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/authentication"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
@@ -22,11 +24,11 @@ type Provider struct {
db *gorm.DB db *gorm.DB
} }
func New(cfg *users.Config) (*Provider, error) { func New(cfg *common.DatabaseConfig) (*Provider, error) {
p := &Provider{} p := &Provider{}
var err error var err error
p.db, err = users.GetDatabaseForConfig(cfg) p.db, err = common.GetDatabaseForConfig(cfg)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database) return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database)
} }

76
internal/common/db.go Normal file
View File

@@ -0,0 +1,76 @@
package common
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
type SupportedDatabase string
const (
SupportedDatabaseMySQL SupportedDatabase = "mysql"
SupportedDatabaseSQLite SupportedDatabase = "sqlite"
)
type DatabaseConfig struct {
Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
Host string `yaml:"host" envconfig:"DATABASE_HOST"`
Port int `yaml:"port" envconfig:"DATABASE_PORT"`
Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
}
func GetDatabaseForConfig(cfg *DatabaseConfig) (db *gorm.DB, err error) {
switch cfg.Typ {
case SupportedDatabaseSQLite:
if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
return
}
}
db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
if err != nil {
return
}
case SupportedDatabaseMySQL:
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
if err != nil {
return
}
sqlDB, _ := db.DB()
sqlDB.SetConnMaxLifetime(time.Minute * 5)
sqlDB.SetMaxIdleConns(2)
sqlDB.SetMaxOpenConns(10)
err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
if err != nil {
return nil, errors.Wrap(err, "failed to ping mysql authentication database")
}
}
// Enable Logger (logrus)
logCfg := logger.Config{
SlowThreshold: time.Second, // all slower than one second
Colorful: false,
LogLevel: logger.Silent, // default: log nothing
}
if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
logCfg.LogLevel = logger.Info
logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
}
db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
return
}

View File

@@ -71,9 +71,9 @@ func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body str
} }
} }
if cfg.CertValidation { if cfg.TLS {
return e.Send(hostname, auth) return e.SendWithStartTLS(hostname, auth, &tls.Config{InsecureSkipVerify: !cfg.CertValidation})
} else { } else {
return e.SendWithStartTLS(hostname, auth, &tls.Config{InsecureSkipVerify: true}) return e.Send(hostname, auth)
} }
} }

View File

@@ -60,6 +60,16 @@ func ListToString(lst []string) string {
return strings.Join(lst, ", ") return strings.Join(lst, ", ")
} }
// ListContains checks if a needle exists in the given list.
func ListContains(lst []string, needle string) bool {
for _, entry := range lst {
if entry == needle {
return true
}
}
return false
}
// https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/ // https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/
func ByteCountSI(b int64) string { func ByteCountSI(b int64) string {
const unit = 1000 const unit = 1000

View File

@@ -10,6 +10,7 @@ const (
type Config struct { type Config struct {
URL string `yaml:"url" envconfig:"LDAP_URL"` URL string `yaml:"url" envconfig:"LDAP_URL"`
StartTLS bool `yaml:"startTLS" envconfig:"LDAP_STARTTLS"` StartTLS bool `yaml:"startTLS" envconfig:"LDAP_STARTTLS"`
CertValidation bool `yaml:"certcheck" envconfig:"LDAP_CERT_VALIDATION"`
BaseDN string `yaml:"dn" envconfig:"LDAP_BASEDN"` BaseDN string `yaml:"dn" envconfig:"LDAP_BASEDN"`
BindUser string `yaml:"user" envconfig:"LDAP_USER"` BindUser string `yaml:"user" envconfig:"LDAP_USER"`
BindPass string `yaml:"pass" envconfig:"LDAP_PASSWORD"` BindPass string `yaml:"pass" envconfig:"LDAP_PASSWORD"`

View File

@@ -23,7 +23,7 @@ func Open(cfg *Config) (*ldap.Conn, error) {
if cfg.StartTLS { if cfg.StartTLS {
// Reconnect with TLS // Reconnect with TLS
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) err = conn.StartTLS(&tls.Config{InsecureSkipVerify: !cfg.CertValidation})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to star TLS on connection") return nil, errors.Wrap(err, "failed to star TLS on connection")
} }
@@ -92,7 +92,7 @@ func IsActiveDirectoryUserDisabled(userAccountControl string) bool {
return false return false
} }
uacInt, err := strconv.Atoi(userAccountControl) uacInt, err := strconv.ParseInt(userAccountControl, 10, 32)
if err != nil { if err != nil {
return true return true
} }

View File

@@ -1,12 +1,12 @@
package common package server
import ( import (
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/ldap" "github.com/h44z/wg-portal/internal/ldap"
"github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard" "github.com/h44z/wg-portal/internal/wireguard"
"github.com/kelseyhightower/envconfig" "github.com/kelseyhightower/envconfig"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -65,9 +65,10 @@ type Config struct {
EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"` EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"`
CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"` CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"`
LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"` LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"`
SessionSecret string `yaml:"sessionSecret" envconfig:"SESSION_SECRET"`
} `yaml:"core"` } `yaml:"core"`
Database users.Config `yaml:"database"` Database common.DatabaseConfig `yaml:"database"`
Email MailConfig `yaml:"email"` Email common.MailConfig `yaml:"email"`
LDAP ldap.Config `yaml:"ldap"` LDAP ldap.Config `yaml:"ldap"`
WG wireguard.Config `yaml:"wg"` WG wireguard.Config `yaml:"wg"`
} }
@@ -84,6 +85,7 @@ func NewConfig() *Config {
cfg.Core.AdminUser = "admin@wgportal.local" cfg.Core.AdminUser = "admin@wgportal.local"
cfg.Core.AdminPassword = "wgportal" cfg.Core.AdminPassword = "wgportal"
cfg.Core.LdapEnabled = false cfg.Core.LdapEnabled = false
cfg.Core.SessionSecret = "secret"
cfg.Database.Typ = "sqlite" cfg.Database.Typ = "sqlite"
cfg.Database.Database = "data/wg_portal.db" cfg.Database.Database = "data/wg_portal.db"
@@ -103,8 +105,9 @@ func NewConfig() *Config {
cfg.LDAP.DisabledAttribute = "userAccountControl" cfg.LDAP.DisabledAttribute = "userAccountControl"
cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL" cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
cfg.WG.DeviceName = "wg0" cfg.WG.DeviceNames = []string{"wg0"}
cfg.WG.WireGuardConfig = "/etc/wireguard/wg0.conf" cfg.WG.DefaultDeviceName = "wg0"
cfg.WG.ConfigDirectoryPath = "/etc/wireguard"
cfg.WG.ManageIPAddresses = true cfg.WG.ManageIPAddresses = true
cfg.Email.Host = "127.0.0.1" cfg.Email.Host = "127.0.0.1"
cfg.Email.Port = 25 cfg.Email.Port = 25

View File

@@ -4,6 +4,8 @@ import (
"net/http" "net/http"
"strings" "strings"
csrf "github.com/utrack/gin-csrf"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/authentication"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
@@ -31,6 +33,7 @@ func (s *Server) GetLogin(c *gin.Context) {
"error": authError != "", "error": authError != "",
"message": errMsg, "message": errMsg,
"static": s.getStaticData(), "static": s.getStaticData(),
"Csrf": csrf.GetToken(c),
}) })
} }
@@ -98,7 +101,7 @@ func (s *Server) PostLogin(c *gin.Context) {
Firstname: userData.Firstname, Firstname: userData.Firstname,
Lastname: userData.Lastname, Lastname: userData.Lastname,
Phone: userData.Phone, Phone: userData.Phone,
}); err != nil { }, s.wg.Cfg.GetDefaultDeviceName()); err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data") s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data")
return return
} }
@@ -121,9 +124,10 @@ func (s *Server) PostLogin(c *gin.Context) {
sessionData.Email = user.Email sessionData.Email = user.Email
sessionData.Firstname = user.Firstname sessionData.Firstname = user.Firstname
sessionData.Lastname = user.Lastname sessionData.Lastname = user.Lastname
sessionData.DeviceName = s.wg.Cfg.DeviceNames[0]
// Check if user already has a peer setup, if not create one // Check if user already has a peer setup, if not create one
if err := s.CreateUserDefaultPeer(user.Email); err != nil { if err := s.CreateUserDefaultPeer(user.Email, s.wg.Cfg.GetDefaultDeviceName()); err != nil {
// Not a fatal error, just log it... // Not a fatal error, just log it...
logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err) logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err)
} }

View File

@@ -4,12 +4,15 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users"
"github.com/pkg/errors"
) )
func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) { func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) {
currentSession := GetSessionData(c)
c.HTML(code, "error.html", gin.H{ c.HTML(code, "error.html", gin.H{
"Data": gin.H{ "Data": gin.H{
"Code": strconv.Itoa(code), "Code": strconv.Itoa(code),
@@ -19,22 +22,21 @@ func (s *Server) GetHandleError(c *gin.Context, code int, message, details strin
"Route": c.Request.URL.Path, "Route": c.Request.URL.Path,
"Session": GetSessionData(c), "Session": GetSessionData(c),
"Static": s.getStaticData(), "Static": s.getStaticData(),
"Device": s.peers.GetDevice(currentSession.DeviceName),
"DeviceNames": s.wg.Cfg.DeviceNames,
}) })
} }
func (s *Server) GetIndex(c *gin.Context) { func (s *Server) GetIndex(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", struct { currentSession := GetSessionData(c)
Route string
Alerts []FlashData c.HTML(http.StatusOK, "index.html", gin.H{
Session SessionData "Route": c.Request.URL.Path,
Static StaticData "Alerts": GetFlashes(c),
Device Device "Session": currentSession,
}{ "Static": s.getStaticData(),
Route: c.Request.URL.Path, "Device": s.peers.GetDevice(currentSession.DeviceName),
Alerts: GetFlashes(c), "DeviceNames": s.wg.Cfg.DeviceNames,
Session: GetSessionData(c),
Static: s.getStaticData(),
Device: s.peers.GetDevice(),
}) })
} }
@@ -74,25 +76,35 @@ func (s *Server) GetAdminIndex(c *gin.Context) {
return return
} }
device := s.peers.GetDevice() deviceName := c.Query("device")
users := s.peers.GetFilteredAndSortedPeers(currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"]) if deviceName != "" {
if !common.ListContains(s.wg.Cfg.DeviceNames, deviceName) {
s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "no such device")
return
}
currentSession.DeviceName = deviceName
c.HTML(http.StatusOK, "admin_index.html", struct { if err := UpdateSessionData(c, currentSession); err != nil {
Route string s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "failed to save session")
Alerts []FlashData return
Session SessionData }
Static StaticData c.Redirect(http.StatusSeeOther, "/admin/")
Peers []Peer return
TotalPeers int }
Device Device
}{ device := s.peers.GetDevice(currentSession.DeviceName)
Route: c.Request.URL.Path, users := s.peers.GetFilteredAndSortedPeers(currentSession.DeviceName, currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"])
Alerts: GetFlashes(c),
Session: currentSession, c.HTML(http.StatusOK, "admin_index.html", gin.H{
Static: s.getStaticData(), "Route": c.Request.URL.Path,
Peers: users, "Alerts": GetFlashes(c),
TotalPeers: len(s.peers.GetAllPeers()), "Session": currentSession,
Device: device, "Static": s.getStaticData(),
"Peers": users,
"TotalPeers": len(s.peers.GetAllPeers(currentSession.DeviceName)),
"Users": s.users.GetUsers(),
"Device": device,
"DeviceNames": s.wg.Cfg.DeviceNames,
}) })
} }
@@ -120,25 +132,18 @@ func (s *Server) GetUserIndex(c *gin.Context) {
return return
} }
device := s.peers.GetDevice() peers := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
users := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
c.HTML(http.StatusOK, "user_index.html", struct { c.HTML(http.StatusOK, "user_index.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peers []Peer "Peers": peers,
TotalPeers int "TotalPeers": len(peers),
Device Device "Users": []users.User{*s.users.GetUser(currentSession.Email)},
}{ "Device": s.peers.GetDevice(currentSession.DeviceName),
Route: c.Request.URL.Path, "DeviceNames": s.wg.Cfg.DeviceNames,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peers: users,
TotalPeers: len(users),
Device: device,
}) })
} }
@@ -158,7 +163,7 @@ func (s *Server) setNewPeerFormInSession(c *gin.Context) (SessionData, error) {
// If session does not contain a peer form ignore update // If session does not contain a peer form ignore update
// If url contains a formerr parameter reset the form // If url contains a formerr parameter reset the form
if currentSession.FormData == nil || c.Query("formerr") == "" { if currentSession.FormData == nil || c.Query("formerr") == "" {
user, err := s.PrepareNewPeer() user, err := s.PrepareNewPeer(currentSession.DeviceName)
if err != nil { if err != nil {
return currentSession, errors.WithMessage(err, "failed to prepare new peer") return currentSession, errors.WithMessage(err, "failed to prepare new peer")
} }

View File

@@ -6,42 +6,36 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/wireguard"
csrf "github.com/utrack/gin-csrf"
) )
func (s *Server) GetAdminEditInterface(c *gin.Context) { func (s *Server) GetAdminEditInterface(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetAllPeers() device := s.peers.GetDevice(currentSession.DeviceName)
currentSession, err := s.setFormInSession(c, device) currentSession, err := s.setFormInSession(c, device)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return return
} }
c.HTML(http.StatusOK, "admin_edit_interface.html", struct { c.HTML(http.StatusOK, "admin_edit_interface.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peers []Peer "Device": currentSession.FormData.(wireguard.Device),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "DeviceNames": s.wg.Cfg.DeviceNames,
}{ "Csrf": csrf.GetToken(c),
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peers: users,
Device: currentSession.FormData.(Device),
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
func (s *Server) PostAdminEditInterface(c *gin.Context) { func (s *Server) PostAdminEditInterface(c *gin.Context) {
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formDevice Device var formDevice wireguard.Device
if currentSession.FormData != nil { if currentSession.FormData != nil {
formDevice = currentSession.FormData.(Device) formDevice = currentSession.FormData.(wireguard.Device)
} }
if err := c.ShouldBind(&formDevice); err != nil { if err := c.ShouldBind(&formDevice); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
@@ -76,7 +70,7 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
} }
// Update WireGuard config file // Update WireGuard config file
err = s.WriteWireGuardConfigFile() err = s.WriteWireGuardConfigFile(currentSession.DeviceName)
if err != nil { if err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger")
@@ -86,12 +80,12 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
// Update interface IP address // Update interface IP address
if s.config.WG.ManageIPAddresses { if s.config.WG.ManageIPAddresses {
if err := s.wg.SetIPAddress(formDevice.IPs); err != nil { if err := s.wg.SetIPAddress(currentSession.DeviceName, formDevice.IPs); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
} }
if err := s.wg.SetMTU(formDevice.Mtu); err != nil { if err := s.wg.SetMTU(currentSession.DeviceName, formDevice.Mtu); err != nil {
_ = s.updateFormInSession(c, formDevice) _ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger") SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
@@ -106,9 +100,10 @@ func (s *Server) PostAdminEditInterface(c *gin.Context) {
} }
func (s *Server) GetInterfaceConfig(c *gin.Context) { func (s *Server) GetInterfaceConfig(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetActivePeers() device := s.peers.GetDevice(currentSession.DeviceName)
cfg, err := device.GetConfigFile(users) peers := s.peers.GetActivePeers(device.DeviceName)
cfg, err := device.GetConfigFile(peers)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
@@ -122,13 +117,14 @@ func (s *Server) GetInterfaceConfig(c *gin.Context) {
} }
func (s *Server) GetApplyGlobalConfig(c *gin.Context) { func (s *Server) GetApplyGlobalConfig(c *gin.Context) {
device := s.peers.GetDevice() currentSession := GetSessionData(c)
users := s.peers.GetAllPeers() device := s.peers.GetDevice(currentSession.DeviceName)
peers := s.peers.GetAllPeers(device.DeviceName)
for _, user := range users { for _, peer := range peers {
user.AllowedIPs = device.AllowedIPs peer.AllowedIPs = device.AllowedIPs
user.AllowedIPsStr = device.AllowedIPsStr peer.AllowedIPsStr = device.AllowedIPsStr
if err := s.peers.UpdatePeer(user); err != nil { if err := s.peers.UpdatePeer(peer); err != nil {
SetFlashMessage(c, err.Error(), "danger") SetFlashMessage(c, err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit") c.Redirect(http.StatusSeeOther, "/admin/device/edit")
} }

View File

@@ -11,8 +11,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tatsushid/go-fastping" "github.com/tatsushid/go-fastping"
csrf "github.com/utrack/gin-csrf"
) )
type LdapCreateForm struct { type LdapCreateForm struct {
@@ -21,7 +23,6 @@ type LdapCreateForm struct {
} }
func (s *Server) GetAdminEditPeer(c *gin.Context) { func (s *Server) GetAdminEditPeer(c *gin.Context) {
device := s.peers.GetDevice()
peer := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession, err := s.setFormInSession(c, peer) currentSession, err := s.setFormInSession(c, peer)
@@ -30,22 +31,16 @@ func (s *Server) GetAdminEditPeer(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_client.html", struct { c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peer Peer "Peer": currentSession.FormData.(wireguard.Peer),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path, "Csrf": csrf.GetToken(c),
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peer: currentSession.FormData.(Peer),
Device: device,
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
@@ -54,9 +49,9 @@ func (s *Server) PostAdminEditPeer(c *gin.Context) {
urlEncodedKey := url.QueryEscape(c.Query("pkey")) urlEncodedKey := url.QueryEscape(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formPeer Peer var formPeer wireguard.Peer
if currentSession.FormData != nil { if currentSession.FormData != nil {
formPeer = currentSession.FormData.(Peer) formPeer = currentSession.FormData.(wireguard.Peer)
} }
if err := c.ShouldBind(&formPeer); err != nil { if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
@@ -92,37 +87,29 @@ func (s *Server) PostAdminEditPeer(c *gin.Context) {
} }
func (s *Server) GetAdminCreatePeer(c *gin.Context) { func (s *Server) GetAdminCreatePeer(c *gin.Context) {
device := s.peers.GetDevice()
currentSession, err := s.setNewPeerFormInSession(c) currentSession, err := s.setNewPeerFormInSession(c)
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return return
} }
c.HTML(http.StatusOK, "admin_edit_client.html", struct { c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Peer Peer "Peer": currentSession.FormData.(wireguard.Peer),
Device Device "EditableKeys": s.config.Core.EditableKeys,
EditableKeys bool "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path, "Csrf": csrf.GetToken(c),
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Peer: currentSession.FormData.(Peer),
Device: device,
EditableKeys: s.config.Core.EditableKeys,
}) })
} }
func (s *Server) PostAdminCreatePeer(c *gin.Context) { func (s *Server) PostAdminCreatePeer(c *gin.Context) {
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
var formPeer Peer var formPeer wireguard.Peer
if currentSession.FormData != nil { if currentSession.FormData != nil {
formPeer = currentSession.FormData.(Peer) formPeer = currentSession.FormData.(wireguard.Peer)
} }
if err := c.ShouldBind(&formPeer); err != nil { if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
@@ -143,7 +130,7 @@ func (s *Server) PostAdminCreatePeer(c *gin.Context) {
formPeer.DeactivatedAt = &now formPeer.DeactivatedAt = &now
} }
if err := s.CreatePeer(formPeer); err != nil { if err := s.CreatePeer(currentSession.DeviceName, formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer) _ = s.updateFormInSession(c, formPeer)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create")
@@ -161,22 +148,16 @@ func (s *Server) GetAdminCreateLdapPeers(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_create_clients.html", struct { c.HTML(http.StatusOK, "admin_create_clients.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Users []users.User "Users": s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
FormData LdapCreateForm "FormData": currentSession.FormData.(LdapCreateForm),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path, "Csrf": csrf.GetToken(c),
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Users: s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
FormData: currentSession.FormData.(LdapCreateForm),
Device: s.peers.GetDevice(),
}) })
} }
@@ -207,7 +188,7 @@ func (s *Server) PostAdminCreateLdapPeers(c *gin.Context) {
logrus.Infof("creating %d ldap peers", len(emails)) logrus.Infof("creating %d ldap peers", len(emails))
for i := range emails { for i := range emails {
if err := s.CreatePeerByEmail(emails[i], formData.Identifier, false); err != nil { if err := s.CreatePeerByEmail(currentSession.DeviceName, emails[i], formData.Identifier, false); err != nil {
_ = s.updateFormInSession(c, formData) _ = s.updateFormInSession(c, formData)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create")
@@ -220,24 +201,24 @@ func (s *Server) PostAdminCreateLdapPeers(c *gin.Context) {
} }
func (s *Server) GetAdminDeletePeer(c *gin.Context) { func (s *Server) GetAdminDeletePeer(c *gin.Context) {
currentUser := s.peers.GetPeerByKey(c.Query("pkey")) currentPeer := s.peers.GetPeerByKey(c.Query("pkey"))
if err := s.DeletePeer(currentUser); err != nil { if err := s.DeletePeer(currentPeer); err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error())
return return
} }
SetFlashMessage(c, "user deleted successfully", "success") SetFlashMessage(c, "peer deleted successfully", "success")
c.Redirect(http.StatusSeeOther, "/admin") c.Redirect(http.StatusSeeOther, "/admin")
} }
func (s *Server) GetPeerQRCode(c *gin.Context) { func (s *Server) GetPeerQRCode(c *gin.Context) {
user := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
if !currentSession.IsAdmin && user.Email != currentSession.Email { if !currentSession.IsAdmin && peer.Email != currentSession.Email {
s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!")
return return
} }
png, err := user.GetQRCode() png, err := peer.GetQRCode()
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "QRCode error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "QRCode error", err.Error())
return return
@@ -247,38 +228,40 @@ func (s *Server) GetPeerQRCode(c *gin.Context) {
} }
func (s *Server) GetPeerConfig(c *gin.Context) { func (s *Server) GetPeerConfig(c *gin.Context) {
user := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
if !currentSession.IsAdmin && user.Email != currentSession.Email { if !currentSession.IsAdmin && peer.Email != currentSession.Email {
s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!")
return return
} }
cfg, err := user.GetConfigFile(s.peers.GetDevice()) cfg, err := peer.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
} }
c.Header("Content-Disposition", "attachment; filename="+user.GetConfigFileName()) c.Header("Content-Disposition", "attachment; filename="+peer.GetConfigFileName())
c.Data(http.StatusOK, "application/config", cfg) c.Data(http.StatusOK, "application/config", cfg)
return return
} }
func (s *Server) GetPeerConfigMail(c *gin.Context) { func (s *Server) GetPeerConfigMail(c *gin.Context) {
user := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
if !currentSession.IsAdmin && user.Email != currentSession.Email { if !currentSession.IsAdmin && peer.Email != currentSession.Email {
s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!")
return return
} }
cfg, err := user.GetConfigFile(s.peers.GetDevice()) user := s.users.GetUser(peer.Email)
cfg, err := peer.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return return
} }
png, err := user.GetQRCode() png, err := peer.GetQRCode()
if err != nil { if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "QRCode error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "QRCode error", err.Error())
return return
@@ -286,11 +269,13 @@ func (s *Server) GetPeerConfigMail(c *gin.Context) {
// Apply mail template // Apply mail template
var tplBuff bytes.Buffer var tplBuff bytes.Buffer
if err := s.mailTpl.Execute(&tplBuff, struct { if err := s.mailTpl.Execute(&tplBuff, struct {
Client Peer Peer wireguard.Peer
User *users.User
QrcodePngName string QrcodePngName string
PortalUrl string PortalUrl string
}{ }{
Client: user, Peer: peer,
User: user,
QrcodePngName: "wireguard-config.png", QrcodePngName: "wireguard-config.png",
PortalUrl: s.config.Core.ExternalUrl, PortalUrl: s.config.Core.ExternalUrl,
}); err != nil { }); err != nil {
@@ -301,7 +286,7 @@ func (s *Server) GetPeerConfigMail(c *gin.Context) {
// Send mail // Send mail
attachments := []common.MailAttachment{ attachments := []common.MailAttachment{
{ {
Name: user.GetConfigFileName(), Name: peer.GetConfigFileName(),
ContentType: "application/config", ContentType: "application/config",
Data: bytes.NewReader(cfg), Data: bytes.NewReader(cfg),
}, },
@@ -314,24 +299,28 @@ func (s *Server) GetPeerConfigMail(c *gin.Context) {
if err := common.SendEmailWithAttachments(s.config.Email, s.config.Core.MailFrom, "", "WireGuard VPN Configuration", if err := common.SendEmailWithAttachments(s.config.Email, s.config.Core.MailFrom, "", "WireGuard VPN Configuration",
"Your mail client does not support HTML. Please find the configuration attached to this mail.", tplBuff.String(), "Your mail client does not support HTML. Please find the configuration attached to this mail.", tplBuff.String(),
[]string{user.Email}, attachments); err != nil { []string{peer.Email}, attachments); err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Email error", err.Error()) s.GetHandleError(c, http.StatusInternalServerError, "Email error", err.Error())
return return
} }
SetFlashMessage(c, "mail sent successfully", "success") SetFlashMessage(c, "mail sent successfully", "success")
if strings.HasPrefix(c.Request.URL.Path, "/user") {
c.Redirect(http.StatusSeeOther, "/user/profile")
} else {
c.Redirect(http.StatusSeeOther, "/admin") c.Redirect(http.StatusSeeOther, "/admin")
} }
}
func (s *Server) GetPeerStatus(c *gin.Context) { func (s *Server) GetPeerStatus(c *gin.Context) {
user := s.peers.GetPeerByKey(c.Query("pkey")) peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession := GetSessionData(c) currentSession := GetSessionData(c)
if !currentSession.IsAdmin && user.Email != currentSession.Email { if !currentSession.IsAdmin && peer.Email != currentSession.Email {
s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!")
return return
} }
if user.Peer == nil { // no peer means disabled if peer.Peer == nil { // no peer means disabled
c.JSON(http.StatusOK, false) c.JSON(http.StatusOK, false)
return return
} }
@@ -339,7 +328,7 @@ func (s *Server) GetPeerStatus(c *gin.Context) {
isOnline := false isOnline := false
ping := make(chan bool) ping := make(chan bool)
defer close(ping) defer close(ping)
for _, cidr := range user.IPs { for _, cidr := range peer.IPs {
ip, _, _ := net.ParseCIDR(cidr) ip, _, _ := net.ParseCIDR(cidr)
var ra *net.IPAddr var ra *net.IPAddr
if common.IsIPv6(ip.String()) { if common.IsIPv6(ip.String()) {

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
csrf "github.com/utrack/gin-csrf"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -49,22 +50,15 @@ func (s *Server) GetAdminUsersIndex(c *gin.Context) {
dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"]) dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"])
c.HTML(http.StatusOK, "admin_user_index.html", struct { c.HTML(http.StatusOK, "admin_user_index.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
Users []users.User "Users": dbUsers,
TotalUsers int "TotalUsers": len(s.users.GetUsers()),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
}{ "DeviceNames": s.wg.Cfg.DeviceNames,
Route: c.Request.URL.Path,
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
Users: dbUsers,
TotalUsers: len(s.users.GetUsers()),
Device: s.peers.GetDevice(),
}) })
} }
@@ -77,21 +71,16 @@ func (s *Server) GetAdminUsersEdit(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_user.html", struct { c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
User users.User "User": currentSession.FormData.(users.User),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
Epoch time.Time "DeviceNames": s.wg.Cfg.DeviceNames,
}{ "Epoch": time.Time{},
Route: c.Request.URL.Path, "Csrf": csrf.GetToken(c),
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
User: currentSession.FormData.(users.User),
Device: s.peers.GetDevice(),
}) })
} }
@@ -160,21 +149,16 @@ func (s *Server) GetAdminUsersCreate(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "admin_edit_user.html", struct { c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
Route string "Route": c.Request.URL.Path,
Alerts []FlashData "Alerts": GetFlashes(c),
Session SessionData "Session": currentSession,
Static StaticData "Static": s.getStaticData(),
User users.User "User": currentSession.FormData.(users.User),
Device Device "Device": s.peers.GetDevice(currentSession.DeviceName),
Epoch time.Time "DeviceNames": s.wg.Cfg.DeviceNames,
}{ "Epoch": time.Time{},
Route: c.Request.URL.Path, "Csrf": csrf.GetToken(c),
Alerts: GetFlashes(c),
Session: currentSession,
Static: s.getStaticData(),
User: currentSession.FormData.(users.User),
Device: s.peers.GetDevice(),
}) })
} }
@@ -218,7 +202,7 @@ func (s *Server) PostAdminUsersCreate(c *gin.Context) {
formUser.IsAdmin = c.PostForm("isadmin") == "true" formUser.IsAdmin = c.PostForm("isadmin") == "true"
formUser.Source = users.UserSourceDatabase formUser.Source = users.UserSourceDatabase
if err := s.CreateUser(formUser); err != nil { if err := s.CreateUser(formUser, currentSession.DeviceName); err != nil {
_ = s.updateFormInSession(c, formUser) _ = s.updateFormInSession(c, formUser)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create")

View File

@@ -4,14 +4,14 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
wg_portal "github.com/h44z/wg-portal" wgportal "github.com/h44z/wg-portal"
) )
func SetupRoutes(s *Server) { func SetupRoutes(s *Server) {
// Startpage // Startpage
s.server.GET("/", s.GetIndex) s.server.GET("/", s.GetIndex)
s.server.GET("/favicon.ico", func(c *gin.Context) { s.server.GET("/favicon.ico", func(c *gin.Context) {
file, _ := wg_portal.Statics.ReadFile("assets/img/favicon.ico") file, _ := wgportal.Statics.ReadFile("assets/img/favicon.ico")
c.Data( c.Data(
http.StatusOK, http.StatusOK,
"image/x-icon", "image/x-icon",

View File

@@ -11,12 +11,13 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
wg_portal "github.com/h44z/wg-portal" wgportal "github.com/h44z/wg-portal"
ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap" ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap"
passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password" passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
@@ -25,6 +26,8 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
ginlogrus "github.com/toorop/gin-logrus" ginlogrus "github.com/toorop/gin-logrus"
csrf "github.com/utrack/gin-csrf"
"gorm.io/gorm"
) )
const SessionIdentifier = "wgPortalSession" const SessionIdentifier = "wgPortalSession"
@@ -32,8 +35,8 @@ const SessionIdentifier = "wgPortalSession"
func init() { func init() {
gob.Register(SessionData{}) gob.Register(SessionData{})
gob.Register(FlashData{}) gob.Register(FlashData{})
gob.Register(Peer{}) gob.Register(wireguard.Peer{})
gob.Register(Device{}) gob.Register(wireguard.Device{})
gob.Register(LdapCreateForm{}) gob.Register(LdapCreateForm{})
gob.Register(users.User{}) gob.Register(users.User{})
} }
@@ -44,6 +47,7 @@ type SessionData struct {
Firstname string Firstname string
Lastname string Lastname string
Email string Email string
DeviceName string
SortedBy map[string]string SortedBy map[string]string
SortDirection map[string]string SortDirection map[string]string
@@ -69,14 +73,15 @@ type StaticData struct {
type Server struct { type Server struct {
ctx context.Context ctx context.Context
config *common.Config config *Config
server *gin.Engine server *gin.Engine
mailTpl *template.Template mailTpl *template.Template
auth *AuthManager auth *AuthManager
db *gorm.DB
users *users.Manager users *users.Manager
wg *wireguard.Manager wg *wireguard.Manager
peers *PeerManager peers *wireguard.PeerManager
} }
func (s *Server) Setup(ctx context.Context) error { func (s *Server) Setup(ctx context.Context) error {
@@ -90,9 +95,15 @@ func (s *Server) Setup(ctx context.Context) error {
// Init rand // Init rand
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
s.config = common.NewConfig() s.config = NewConfig()
s.ctx = ctx s.ctx = ctx
// Setup database connection
s.db, err = common.GetDatabaseForConfig(&s.config.Database)
if err != nil {
return errors.WithMessage(err, "database setup failed")
}
// Setup http server // Setup http server
gin.SetMode(gin.DebugMode) gin.SetMode(gin.DebugMode)
gin.DefaultWriter = ioutil.Discard gin.DefaultWriter = ioutil.Discard
@@ -101,27 +112,43 @@ func (s *Server) Setup(ctx context.Context) error {
s.server.Use(ginlogrus.Logger(logrus.StandardLogger())) s.server.Use(ginlogrus.Logger(logrus.StandardLogger()))
} }
s.server.Use(gin.Recovery()) s.server.Use(gin.Recovery())
s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte(s.config.Core.SessionSecret))))
s.server.Use(csrf.Middleware(csrf.Options{
Secret: s.config.Core.SessionSecret,
ErrorFunc: func(c *gin.Context) {
c.String(400, "CSRF token mismatch")
c.Abort()
},
}))
s.server.SetFuncMap(template.FuncMap{ s.server.SetFuncMap(template.FuncMap{
"formatBytes": common.ByteCountSI, "formatBytes": common.ByteCountSI,
"urlEncode": url.QueryEscape, "urlEncode": url.QueryEscape,
"startsWith": strings.HasPrefix,
"userForEmail": func(users []users.User, email string) *users.User {
for i := range users {
if users[i].Email == email {
return &users[i]
}
}
return nil
},
}) })
// Setup templates // Setup templates
templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wg_portal.Templates, "assets/tpl/*.html")) templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wgportal.Templates, "assets/tpl/*.html"))
s.server.SetHTMLTemplate(templates) s.server.SetHTMLTemplate(templates)
s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key?
// Serve static files // Serve static files
s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/css")))) s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/css"))))
s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/js")))) s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/js"))))
s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/img")))) s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/img"))))
s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/fonts")))) s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/fonts"))))
// Setup all routes // Setup all routes
SetupRoutes(s) SetupRoutes(s)
// Setup user database (also needed for database authentication) // Setup user database (also needed for database authentication)
s.users, err = users.NewManager(&s.config.Database) s.users, err = users.NewManager(s.db)
if err != nil { if err != nil {
return errors.WithMessage(err, "user-manager initialization failed") return errors.WithMessage(err, "user-manager initialization failed")
} }
@@ -153,18 +180,21 @@ func (s *Server) Setup(ctx context.Context) error {
} }
// Setup peer manager // Setup peer manager
if s.peers, err = NewPeerManager(s.config, s.wg, s.users); err != nil { if s.peers, err = wireguard.NewPeerManager(s.db, s.wg); err != nil {
return errors.WithMessage(err, "unable to setup peer manager") return errors.WithMessage(err, "unable to setup peer manager")
} }
if err = s.peers.InitFromCurrentInterface(); err != nil { if err = s.peers.InitFromPhysicalInterface(); err != nil {
return errors.WithMessage(err, "unable to initialize peer manager") return errors.WithMessagef(err, "unable to initialize peer manager")
}
for _, deviceName := range s.wg.Cfg.DeviceNames {
if err = s.RestoreWireGuardInterface(deviceName); err != nil {
return errors.WithMessagef(err, "unable to restore WireGuard state for %s", deviceName)
} }
if err = s.RestoreWireGuardInterface(); err != nil {
return errors.WithMessage(err, "unable to restore WireGuard state")
} }
// Setup mail template // Setup mail template
s.mailTpl, err = template.New("email.html").ParseFS(wg_portal.Templates, "assets/tpl/email.html") s.mailTpl, err = template.New("email.html").ParseFS(wgportal.Templates, "assets/tpl/email.html")
if err != nil { if err != nil {
return errors.Wrap(err, "unable to pare mail template") return errors.Wrap(err, "unable to pare mail template")
} }
@@ -174,6 +204,8 @@ func (s *Server) Setup(ctx context.Context) error {
} }
func (s *Server) Run() { func (s *Server) Run() {
logrus.Infof("starting web service on %s", s.config.Core.ListeningAddress)
// Start ldap sync // Start ldap sync
if s.config.Core.LdapEnabled { if s.config.Core.LdapEnabled {
go s.SyncLdapWithUserDatabase() go s.SyncLdapWithUserDatabase()
@@ -233,11 +265,12 @@ func GetSessionData(c *gin.Context) SessionData {
} else { } else {
sessionData = SessionData{ sessionData = SessionData{
Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, Search: map[string]string{"peers": "", "userpeers": "", "users": ""},
SortedBy: map[string]string{"peers": "mail", "userpeers": "mail", "users": "email"}, SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"},
SortDirection: map[string]string{"peers": "asc", "userpeers": "asc", "users": "asc"}, SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"},
Email: "", Email: "",
Firstname: "", Firstname: "",
Lastname: "", Lastname: "",
DeviceName: "",
IsAdmin: false, IsAdmin: false,
LoggedIn: false, LoggedIn: false,
} }

View File

@@ -4,39 +4,42 @@ import (
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"path"
"syscall" "syscall"
"time" "time"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
func (s *Server) PrepareNewPeer() (Peer, error) { // PrepareNewPeer initiates a new peer for the given WireGuard device.
device := s.peers.GetDevice() func (s *Server) PrepareNewPeer(device string) (wireguard.Peer, error) {
dev := s.peers.GetDevice(device)
peer := Peer{} peer := wireguard.Peer{}
peer.IsNew = true peer.IsNew = true
peer.AllowedIPsStr = device.AllowedIPsStr peer.AllowedIPsStr = dev.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs)) peer.IPs = make([]string, len(dev.IPs))
for i := range device.IPs { for i := range dev.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil { if err != nil {
return Peer{}, errors.WithMessage(err, "failed to get available IP addresses") return wireguard.Peer{}, errors.WithMessage(err, "failed to get available IP addresses")
} }
peer.IPs[i] = freeIP peer.IPs[i] = freeIP
} }
peer.IPsStr = common.ListToString(peer.IPs) peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey() psk, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
return Peer{}, errors.Wrap(err, "failed to generate key") return wireguard.Peer{}, errors.Wrap(err, "failed to generate key")
} }
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return Peer{}, errors.Wrap(err, "failed to generate private key") return wireguard.Peer{}, errors.Wrap(err, "failed to generate private key")
} }
peer.PresharedKey = psk.String() peer.PresharedKey = psk.String()
peer.PrivateKey = key.String() peer.PrivateKey = key.String()
@@ -46,54 +49,39 @@ func (s *Server) PrepareNewPeer() (Peer, error) {
return peer, nil return peer, nil
} }
func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error { // CreatePeerByEmail creates a new peer for the given email. If no user with the specified email was found, a new one
// will be created.
func (s *Server) CreatePeerByEmail(device, email, identifierSuffix string, disabled bool) error {
user, err := s.users.GetOrCreateUser(email) user, err := s.users.GetOrCreateUser(email)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to load/create related user %s", email) return errors.WithMessagef(err, "failed to load/create related user %s", email)
} }
device := s.peers.GetDevice() peer, err := s.PrepareNewPeer(device)
peer := Peer{}
peer.User = user
peer.AllowedIPsStr = device.AllowedIPsStr
peer.IPs = make([]string, len(device.IPs))
for i := range device.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to get available IP addresses") return errors.WithMessage(err, "failed to prepare new peer")
} }
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey()
if err != nil {
return errors.Wrap(err, "failed to generate key")
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return errors.Wrap(err, "failed to generate private key")
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.Email = email peer.Email = email
peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix) peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix)
now := time.Now() now := time.Now()
if disabled { if disabled {
peer.DeactivatedAt = &now peer.DeactivatedAt = &now
} }
return s.CreatePeer(peer) return s.CreatePeer(device, peer)
} }
func (s *Server) CreatePeer(peer Peer) error { // CreatePeer creates the new peer in the database. If the peer has no assigned ip addresses, a new one will be assigned
device := s.peers.GetDevice() // automatically. Also, if the private key is empty, a new key-pair will be generated.
peer.AllowedIPsStr = device.AllowedIPsStr // This function also configures the new peer on the physical WireGuard interface if the peer is not deactivated.
func (s *Server) CreatePeer(device string, peer wireguard.Peer) error {
dev := s.peers.GetDevice(device)
peer.AllowedIPsStr = dev.AllowedIPsStr
if peer.IPs == nil || len(peer.IPs) == 0 { if peer.IPs == nil || len(peer.IPs) == 0 {
peer.IPs = make([]string, len(device.IPs)) peer.IPs = make([]string, len(dev.IPs))
for i := range device.IPs { for i := range dev.IPs {
freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to get available IP addresses") return errors.WithMessage(err, "failed to get available IP addresses")
} }
@@ -114,11 +102,12 @@ func (s *Server) CreatePeer(peer Peer) error {
peer.PrivateKey = key.String() peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String() peer.PublicKey = key.PublicKey().String()
} }
peer.DeviceName = dev.DeviceName
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
// Create WireGuard interface // Create WireGuard interface
if peer.DeactivatedAt == nil { if peer.DeactivatedAt == nil {
if err := s.wg.AddPeer(peer.GetConfig()); err != nil { if err := s.wg.AddPeer(device, peer.GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer") return errors.WithMessage(err, "failed to add WireGuard peer")
} }
} }
@@ -128,21 +117,22 @@ func (s *Server) CreatePeer(peer Peer) error {
return errors.WithMessage(err, "failed to create peer") return errors.WithMessage(err, "failed to create peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(device)
} }
func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { // UpdatePeer updates the physical WireGuard interface and the database.
func (s *Server) UpdatePeer(peer wireguard.Peer, updateTime time.Time) error {
currentPeer := s.peers.GetPeerByKey(peer.PublicKey) currentPeer := s.peers.GetPeerByKey(peer.PublicKey)
// Update WireGuard device // Update WireGuard device
var err error var err error
switch { switch {
case peer.DeactivatedAt == &updateTime: case peer.DeactivatedAt == &updateTime:
err = s.wg.RemovePeer(peer.PublicKey) err = s.wg.RemovePeer(peer.DeviceName, peer.PublicKey)
case peer.DeactivatedAt == nil && currentPeer.Peer != nil: case peer.DeactivatedAt == nil && currentPeer.Peer != nil:
err = s.wg.UpdatePeer(peer.GetConfig()) err = s.wg.UpdatePeer(peer.DeviceName, peer.GetConfig())
case peer.DeactivatedAt == nil && currentPeer.Peer == nil: case peer.DeactivatedAt == nil && currentPeer.Peer == nil:
err = s.wg.AddPeer(peer.GetConfig()) err = s.wg.AddPeer(peer.DeviceName, peer.GetConfig())
} }
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to update WireGuard peer") return errors.WithMessage(err, "failed to update WireGuard peer")
@@ -153,12 +143,13 @@ func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
return errors.WithMessage(err, "failed to update peer") return errors.WithMessage(err, "failed to update peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(peer.DeviceName)
} }
func (s *Server) DeletePeer(peer Peer) error { // DeletePeer removes the peer from the physical WireGuard interface and the database.
func (s *Server) DeletePeer(peer wireguard.Peer) error {
// Delete WireGuard peer // Delete WireGuard peer
if err := s.wg.RemovePeer(peer.PublicKey); err != nil { if err := s.wg.RemovePeer(peer.DeviceName, peer.PublicKey); err != nil {
return errors.WithMessage(err, "failed to remove WireGuard peer") return errors.WithMessage(err, "failed to remove WireGuard peer")
} }
@@ -167,15 +158,16 @@ func (s *Server) DeletePeer(peer Peer) error {
return errors.WithMessage(err, "failed to remove peer") return errors.WithMessage(err, "failed to remove peer")
} }
return s.WriteWireGuardConfigFile() return s.WriteWireGuardConfigFile(peer.DeviceName)
} }
func (s *Server) RestoreWireGuardInterface() error { // RestoreWireGuardInterface restores the state of the physical WireGuard interface from the database.
activePeers := s.peers.GetActivePeers() func (s *Server) RestoreWireGuardInterface(device string) error {
activePeers := s.peers.GetActivePeers(device)
for i := range activePeers { for i := range activePeers {
if activePeers[i].Peer == nil { if activePeers[i].Peer == nil {
if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil { if err := s.wg.AddPeer(device, activePeers[i].GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer") return errors.WithMessage(err, "failed to add WireGuard peer")
} }
} }
@@ -184,26 +176,29 @@ func (s *Server) RestoreWireGuardInterface() error {
return nil return nil
} }
func (s *Server) WriteWireGuardConfigFile() error { // WriteWireGuardConfigFile writes the configuration file for the physical WireGuard interface.
if s.config.WG.WireGuardConfig == "" { func (s *Server) WriteWireGuardConfigFile(device string) error {
if s.config.WG.ConfigDirectoryPath == "" {
return nil // writing disabled return nil // writing disabled
} }
if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil { if err := syscall.Access(s.config.WG.ConfigDirectoryPath, syscall.O_RDWR); err != nil {
return errors.Wrap(err, "failed to check WireGuard config access rights") return errors.Wrap(err, "failed to check WireGuard config access rights")
} }
device := s.peers.GetDevice() dev := s.peers.GetDevice(device)
cfg, err := device.GetConfigFile(s.peers.GetActivePeers()) cfg, err := dev.GetConfigFile(s.peers.GetActivePeers(device))
if err != nil { if err != nil {
return errors.WithMessage(err, "failed to get config file") return errors.WithMessage(err, "failed to get config file")
} }
if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil { filePath := path.Join(s.config.WG.ConfigDirectoryPath, dev.DeviceName+".conf")
if err := ioutil.WriteFile(filePath, cfg, 0644); err != nil {
return errors.Wrap(err, "failed to write WireGuard config file") return errors.Wrap(err, "failed to write WireGuard config file")
} }
return nil return nil
} }
func (s *Server) CreateUser(user users.User) error { // CreateUser creates the user in the database and optionally adds a default WireGuard peer for the user.
func (s *Server) CreateUser(user users.User, device string) error {
if user.Email == "" { if user.Email == "" {
return errors.New("cannot create user with empty email address") return errors.New("cannot create user with empty email address")
} }
@@ -220,9 +215,11 @@ func (s *Server) CreateUser(user users.User) error {
} }
// Check if user already has a peer setup, if not, create one // Check if user already has a peer setup, if not, create one
return s.CreateUserDefaultPeer(user.Email) return s.CreateUserDefaultPeer(user.Email, device)
} }
// UpdateUser updates the user in the database. If the user is marked as deleted, it will get remove from the database.
// Also, if the user is re-enabled, all it's linked WireGuard peers will be activated again.
func (s *Server) UpdateUser(user users.User) error { func (s *Server) UpdateUser(user users.User) error {
if user.DeletedAt.Valid { if user.DeletedAt.Valid {
return s.DeleteUser(user) return s.DeleteUser(user)
@@ -249,6 +246,8 @@ func (s *Server) UpdateUser(user users.User) error {
return nil return nil
} }
// DeleteUser removes the user from the database.
// Also, if the user has linked WireGuard peers, they will be deactivated.
func (s *Server) DeleteUser(user users.User) error { func (s *Server) DeleteUser(user users.User) error {
currentUser := s.users.GetUserUnscoped(user.Email) currentUser := s.users.GetUserUnscoped(user.Email)
@@ -271,7 +270,7 @@ func (s *Server) DeleteUser(user users.User) error {
return nil return nil
} }
func (s *Server) CreateUserDefaultPeer(email string) error { func (s *Server) CreateUserDefaultPeer(email, device string) error {
// Check if user is active, if not, quit // Check if user is active, if not, quit
var existingUser *users.User var existingUser *users.User
if existingUser = s.users.GetUser(email); existingUser == nil { if existingUser = s.users.GetUser(email); existingUser == nil {
@@ -282,7 +281,7 @@ func (s *Server) CreateUserDefaultPeer(email string) error {
if s.config.Core.CreateDefaultPeer { if s.config.Core.CreateDefaultPeer {
peers := s.peers.GetPeersByMail(email) peers := s.peers.GetPeersByMail(email)
if len(peers) == 0 { // Create default vpn peer if len(peers) == 0 { // Create default vpn peer
if err := s.CreatePeer(Peer{ if err := s.CreatePeer(device, wireguard.Peer{
Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)", Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)",
Email: existingUser.Email, Email: existingUser.Email,
CreatedBy: existingUser.Email, CreatedBy: existingUser.Email,

View File

@@ -1,17 +0,0 @@
package users
type SupportedDatabase string
const (
SupportedDatabaseMySQL SupportedDatabase = "mysql"
SupportedDatabaseSQLite SupportedDatabase = "sqlite"
)
type Config struct {
Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
Host string `yaml:"host" envconfig:"DATABASE_HOST"`
Port int `yaml:"port" envconfig:"DATABASE_PORT"`
Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
}

View File

@@ -1,9 +1,6 @@
package users package users
import ( import (
"fmt"
"os"
"path/filepath"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -11,69 +8,15 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
) )
func GetDatabaseForConfig(cfg *Config) (db *gorm.DB, err error) {
switch cfg.Typ {
case SupportedDatabaseSQLite:
if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
return
}
}
db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
if err != nil {
return
}
case SupportedDatabaseMySQL:
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
if err != nil {
return
}
sqlDB, _ := db.DB()
sqlDB.SetConnMaxLifetime(time.Minute * 5)
sqlDB.SetMaxIdleConns(2)
sqlDB.SetMaxOpenConns(10)
err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
if err != nil {
return nil, errors.Wrap(err, "failed to ping mysql authentication database")
}
}
// Enable Logger (logrus)
logCfg := logger.Config{
SlowThreshold: time.Second, // all slower than one second
Colorful: false,
LogLevel: logger.Silent, // default: log nothing
}
if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
logCfg.LogLevel = logger.Info
logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
}
db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
return
}
type Manager struct { type Manager struct {
db *gorm.DB db *gorm.DB
} }
func NewManager(cfg *Config) (*Manager, error) { func NewManager(db *gorm.DB) (*Manager, error) {
m := &Manager{} m := &Manager{db: db}
var err error
m.db, err = GetDatabaseForConfig(cfg)
if err != nil {
return nil, errors.Wrapf(err, "failed to setup user database %s", cfg.Database)
}
// check if old user table exists (from version <= 1.0.2), if so rename it to peers. // check if old user table exists (from version <= 1.0.2), if so rename it to peers.
if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") { if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") {
@@ -84,14 +27,11 @@ func NewManager(cfg *Config) (*Manager, error) {
} }
} }
return m, m.MigrateUserDB() if err := m.db.AutoMigrate(&User{}); err != nil {
return nil, errors.Wrap(err, "failed to migrate user database")
} }
func (m Manager) MigrateUserDB() error { return m, nil
if err := m.db.AutoMigrate(&User{}); err != nil {
return errors.Wrap(err, "failed to migrate user database")
}
return nil
} }
func (m Manager) GetUsers() []User { func (m Manager) GetUsers() []User {

View File

@@ -1,7 +1,17 @@
package wireguard package wireguard
import "github.com/h44z/wg-portal/internal/common"
type Config struct { type Config struct {
DeviceName string `yaml:"device" envconfig:"WG_DEVICE"` DeviceNames []string `yaml:"devices" envconfig:"WG_DEVICES"` // managed devices
WireGuardConfig string `yaml:"configFile" envconfig:"WG_CONFIG_FILE"` // optional, if set, updates will be written to this file DefaultDeviceName string `yaml:"devices" envconfig:"WG_DEFAULT_DEVICE"` // this device is used for auto-created peers, use GetDefaultDeviceName() to access this field
ConfigDirectoryPath string `yaml:"configDirectory" envconfig:"WG_CONFIG_PATH"` // optional, if set, updates will be written to this path, filename: <devicename>.conf
ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface
} }
func (c Config) GetDefaultDeviceName() string {
if c.DefaultDeviceName == "" || !common.ListContains(c.DeviceNames, c.DefaultDeviceName) {
return c.DeviceNames[0]
}
return c.DefaultDeviceName
}

View File

@@ -9,6 +9,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Manager offers a synchronized management interface to the real WireGuard interface.
type Manager struct { type Manager struct {
Cfg *Config Cfg *Config
wg *wgctrl.Client wg *wgctrl.Client
@@ -25,8 +26,8 @@ func (m *Manager) Init() error {
return nil return nil
} }
func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) { func (m *Manager) GetDeviceInfo(device string) (*wgtypes.Device, error) {
dev, err := m.wg.Device(m.Cfg.DeviceName) dev, err := m.wg.Device(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device") return nil, errors.Wrap(err, "could not get WireGuard device")
} }
@@ -34,11 +35,11 @@ func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
return dev, nil return dev, nil
} }
func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) { func (m *Manager) GetPeerList(device string) ([]wgtypes.Peer, error) {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
dev, err := m.wg.Device(m.Cfg.DeviceName) dev, err := m.wg.Device(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device") return nil, errors.Wrap(err, "could not get WireGuard device")
} }
@@ -46,7 +47,7 @@ func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
return dev.Peers, nil return dev.Peers, nil
} }
func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) { func (m *Manager) GetPeer(device string, pubKey string) (*wgtypes.Peer, error) {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
@@ -55,7 +56,7 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
return nil, errors.Wrap(err, "invalid public key") return nil, errors.Wrap(err, "invalid public key")
} }
peers, err := m.GetPeerList() peers, err := m.GetPeerList(device)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard peers") return nil, errors.Wrap(err, "could not get WireGuard peers")
} }
@@ -69,11 +70,11 @@ func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey) return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey)
} }
func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { func (m *Manager) AddPeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@@ -81,12 +82,12 @@ func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
return nil return nil
} }
func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error { func (m *Manager) UpdatePeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
cfg.UpdateOnly = true cfg.UpdateOnly = true
err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@@ -94,7 +95,7 @@ func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error {
return nil return nil
} }
func (m *Manager) RemovePeer(pubKey string) error { func (m *Manager) RemovePeer(device string, pubKey string) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@@ -108,7 +109,7 @@ func (m *Manager) RemovePeer(pubKey string) error {
Remove: true, Remove: true,
} }
err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}) err = m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
if err != nil { if err != nil {
return errors.Wrap(err, "could not configure WireGuard device") return errors.Wrap(err, "could not configure WireGuard device")
} }
@@ -116,6 +117,6 @@ func (m *Manager) RemovePeer(pubKey string) error {
return nil return nil
} }
func (m *Manager) UpdateDevice(name string, cfg wgtypes.Config) error { func (m *Manager) UpdateDevice(device string, cfg wgtypes.Config) error {
return m.wg.ConfigureDevice(name, cfg) return m.wg.ConfigureDevice(device, cfg)
} }

View File

@@ -11,10 +11,10 @@ import (
const DefaultMTU = 1420 const DefaultMTU = 1420
func (m *Manager) GetIPAddress() ([]string, error) { func (m *Manager) GetIPAddress(device string) ([]string, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// Get golang net.interface // Get golang net.interface
@@ -52,14 +52,14 @@ func (m *Manager) GetIPAddress() ([]string, error) {
return ipAddresses, nil return ipAddresses, nil
} }
func (m *Manager) SetIPAddress(cidrs []string) error { func (m *Manager) SetIPAddress(device string, cidrs []string) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// First remove existing IP addresses // First remove existing IP addresses
existingIPs, err := m.GetIPAddress() existingIPs, err := m.GetIPAddress(device)
if err != nil { if err != nil {
return errors.Wrap(err, "could not retrieve IP addresses") return errors.Wrap(err, "could not retrieve IP addresses")
} }
@@ -89,10 +89,10 @@ func (m *Manager) SetIPAddress(cidrs []string) error {
return nil return nil
} }
func (m *Manager) GetMTU() (int, error) { func (m *Manager) GetMTU(device string) (int, error) {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
// Get golang net.interface // Get golang net.interface
@@ -104,10 +104,10 @@ func (m *Manager) GetMTU() (int, error) {
return iface.MTU, nil return iface.MTU, nil
} }
func (m *Manager) SetMTU(mtu int) error { func (m *Manager) SetMTU(device string, mtu int) error {
wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) wgInterface, err := tenus.NewLinkFrom(device)
if err != nil { if err != nil {
return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
} }
if mtu == 0 { if mtu == 0 {
@@ -115,7 +115,7 @@ func (m *Manager) SetMTU(mtu int) error {
} }
if err := wgInterface.SetLinkMTU(mtu); err != nil { if err := wgInterface.SetLinkMTU(mtu); err != nil {
return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName) return errors.Wrapf(err, "could not set MTU on interface %s", device)
} }
return nil return nil

View File

@@ -1,4 +1,4 @@
package server package wireguard
import ( import (
"bytes" "bytes"
@@ -15,8 +15,6 @@ import (
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/h44z/wg-portal/internal/common" "github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users"
"github.com/h44z/wg-portal/internal/wireguard"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
@@ -66,7 +64,6 @@ func init() {
type Peer struct { type Peer struct {
Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer
User *users.User `gorm:"-"` // user reference for the peer
Config string `gorm:"-"` Config string `gorm:"-"`
UID string `form:"uid" binding:"alphanum"` // uid for html identification UID string `form:"uid" binding:"alphanum"` // uid for html identification
@@ -85,6 +82,7 @@ type Peer struct {
IPs []string `gorm:"-"` // The IPs of the client IPs []string `gorm:"-"` // The IPs of the client
PrivateKey string `form:"privkey" binding:"omitempty,base64"` PrivateKey string `form:"privkey" binding:"omitempty,base64"`
PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"` PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"`
DeviceName string `gorm:"index"`
DeactivatedAt *time.Time DeactivatedAt *time.Time
CreatedBy string CreatedBy string
@@ -122,7 +120,7 @@ func (p Peer) GetConfig() wgtypes.PeerConfig {
} }
func (p Peer) GetConfigFile(device Device) ([]byte, error) { func (p Peer) GetConfigFile(device Device) ([]byte, error) {
tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(ClientCfgTpl)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to parse client template") return nil, errors.Wrap(err, "failed to parse client template")
} }
@@ -245,7 +243,7 @@ func (d Device) GetConfig() wgtypes.Config {
} }
func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl) tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(DeviceCfgTpl)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to parse server template") return nil, errors.Wrap(err, "failed to parse server template")
} }
@@ -272,62 +270,70 @@ func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
type PeerManager struct { type PeerManager struct {
db *gorm.DB db *gorm.DB
wg *wireguard.Manager wg *Manager
users *users.Manager
} }
func NewPeerManager(cfg *common.Config, wg *wireguard.Manager, userDB *users.Manager) (*PeerManager, error) { func NewPeerManager(db *gorm.DB, wg *Manager) (*PeerManager, error) {
um := &PeerManager{wg: wg, users: userDB} pm := &PeerManager{db: db, wg: wg}
var err error
um.db, err = users.GetDatabaseForConfig(&cfg.Database)
if err != nil {
return nil, errors.WithMessage(err, "failed to open peer database")
}
err = um.db.AutoMigrate(&Peer{}, &Device{}) if err := pm.db.AutoMigrate(&Peer{}, &Device{}); err != nil {
if err != nil {
return nil, errors.WithMessage(err, "failed to migrate peer database") return nil, errors.WithMessage(err, "failed to migrate peer database")
} }
return um, nil // check if peers without device name exist (from version <= 1.0.3), if so assign them to the default device.
peers := make([]Peer, 0)
pm.db.Find(&peers)
for i := range peers {
if peers[i].DeviceName == "" {
peers[i].DeviceName = wg.Cfg.GetDefaultDeviceName()
pm.db.Save(&peers[i])
}
} }
func (u *PeerManager) InitFromCurrentInterface() error { return pm, nil
peers, err := u.wg.GetPeerList()
if err != nil {
return errors.Wrapf(err, "failed to get peer list")
} }
device, err := u.wg.GetDeviceInfo()
// InitFromPhysicalInterface read all WireGuard peers from the WireGuard interface configuration. If a peer does not
// exist in the local database, it gets created.
func (m *PeerManager) InitFromPhysicalInterface() error {
for _, deviceName := range m.wg.Cfg.DeviceNames {
peers, err := m.wg.GetPeerList(deviceName)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to get device info") return errors.Wrapf(err, "failed to get peer list for device %s", deviceName)
}
device, err := m.wg.GetDeviceInfo(deviceName)
if err != nil {
return errors.Wrapf(err, "failed to get device info for device %s", deviceName)
} }
var ipAddresses []string var ipAddresses []string
var mtu int var mtu int
if u.wg.Cfg.ManageIPAddresses { if m.wg.Cfg.ManageIPAddresses {
if ipAddresses, err = u.wg.GetIPAddress(); err != nil { if ipAddresses, err = m.wg.GetIPAddress(deviceName); err != nil {
return errors.Wrapf(err, "failed to get ip address") return errors.Wrapf(err, "failed to get ip address for device %s", deviceName)
} }
if mtu, err = u.wg.GetMTU(); err != nil { if mtu, err = m.wg.GetMTU(deviceName); err != nil {
return errors.Wrapf(err, "failed to get MTU") return errors.Wrapf(err, "failed to get MTU for device %s", deviceName)
} }
} }
// Check if entries already exist in database, if not create them // Check if entries already exist in database, if not create them
for _, peer := range peers { for _, peer := range peers {
if err := u.validateOrCreatePeer(peer); err != nil { if err := m.validateOrCreatePeer(deviceName, peer); err != nil {
return errors.WithMessagef(err, "failed to validate peer %s", peer.PublicKey) return errors.WithMessagef(err, "failed to validate peer %s for device %s", peer.PublicKey, deviceName)
} }
} }
if err := u.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil { if err := m.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil {
return errors.WithMessagef(err, "failed to validate device %s", device.Name) return errors.WithMessagef(err, "failed to validate device %s", device.Name)
} }
}
return nil return nil
} }
func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error { // validateOrCreatePeer checks if the given WireGuard peer already exists in the database, if not, the peer entry will be created
func (m *PeerManager) validateOrCreatePeer(device string, wgPeer wgtypes.Peer) error {
peer := Peer{} peer := Peer{}
u.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer) m.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer)
if peer.PublicKey == "" { // peer not found, create if peer.PublicKey == "" { // peer not found, create
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String()))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String())))
@@ -347,8 +353,9 @@ func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error {
} }
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
peer.DeviceName = device
res := u.db.Create(&peer) res := m.db.Create(&peer)
if res.Error != nil { if res.Error != nil {
return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey) return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey)
} }
@@ -357,9 +364,10 @@ func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error {
return nil return nil
} }
func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error { // validateOrCreateDevice checks if the given WireGuard device already exists in the database, if not, the peer entry will be created
func (m *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error {
device := Device{} device := Device{}
u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device) m.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
if device.PublicKey == "" { // device not found, create if device.PublicKey == "" { // device not found, create
device.PublicKey = dev.PublicKey.String() device.PublicKey = dev.PublicKey.String()
@@ -369,12 +377,12 @@ func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []s
device.Mtu = 0 device.Mtu = 0
device.PersistentKeepalive = 16 // Default device.PersistentKeepalive = 16 // Default
device.IPsStr = strings.Join(ipAddresses, ", ") device.IPsStr = strings.Join(ipAddresses, ", ")
if mtu == wireguard.DefaultMTU { if mtu == DefaultMTU {
mtu = 0 mtu = 0
} }
device.Mtu = mtu device.Mtu = mtu
res := u.db.Create(&device) res := m.db.Create(&device)
if res.Error != nil { if res.Error != nil {
return errors.Wrapf(res.Error, "failed to create autodetected device") return errors.Wrapf(res.Error, "failed to create autodetected device")
} }
@@ -383,21 +391,22 @@ func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []s
return nil return nil
} }
func (u *PeerManager) populatePeerData(peer *Peer) { // populatePeerData enriches the peer struct with WireGuard live data like last handshake, ...
func (m *PeerManager) populatePeerData(peer *Peer) {
peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ") peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ")
peer.IPs = strings.Split(peer.IPsStr, ", ") peer.IPs = strings.Split(peer.IPsStr, ", ")
// Set config file // Set config file
tmpCfg, _ := peer.GetConfigFile(u.GetDevice()) tmpCfg, _ := peer.GetConfigFile(m.GetDevice(peer.DeviceName))
peer.Config = string(tmpCfg) peer.Config = string(tmpCfg)
// set data from WireGuard interface // set data from WireGuard interface
peer.Peer, _ = u.wg.GetPeer(peer.PublicKey) peer.Peer, _ = m.wg.GetPeer(peer.DeviceName, peer.PublicKey)
peer.LastHandshake = "never" peer.LastHandshake = "never"
peer.LastHandshakeTime = "Never connected, or user is disabled." peer.LastHandshakeTime = "Never connected, or user is disabled."
if peer.Peer != nil { if peer.Peer != nil {
since := time.Since(peer.Peer.LastHandshakeTime) since := time.Since(peer.Peer.LastHandshakeTime)
sinceSeconds := int(since.Round(time.Second).Seconds()) sinceSeconds := int(since.Round(time.Second).Seconds())
sinceMinutes := int(sinceSeconds / 60) sinceMinutes := sinceSeconds / 60
sinceSeconds -= sinceMinutes * 60 sinceSeconds -= sinceMinutes * 60
if sinceMinutes > 2*10080 { // 2 weeks if sinceMinutes > 2*10080 { // 2 weeks
@@ -410,49 +419,47 @@ func (u *PeerManager) populatePeerData(peer *Peer) {
peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate) peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate)
} }
peer.IsOnline = false peer.IsOnline = false
// set user data
peer.User = u.users.GetUser(peer.Email)
} }
func (u *PeerManager) populateDeviceData(device *Device) { // populateDeviceData enriches the device struct with WireGuard live data like interface information
func (m *PeerManager) populateDeviceData(device *Device) {
device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ") device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
device.IPs = strings.Split(device.IPsStr, ", ") device.IPs = strings.Split(device.IPsStr, ", ")
device.DNS = strings.Split(device.DNSStr, ", ") device.DNS = strings.Split(device.DNSStr, ", ")
// set data from WireGuard interface // set data from WireGuard interface
device.Interface, _ = u.wg.GetDeviceInfo() device.Interface, _ = m.wg.GetDeviceInfo(device.DeviceName)
} }
func (u *PeerManager) GetAllPeers() []Peer { func (m *PeerManager) GetAllPeers(device string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Find(&peers) m.db.Where("device_name = ?", device).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) GetActivePeers() []Peer { func (m *PeerManager) GetActivePeers(device string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Where("deactivated_at IS NULL").Find(&peers) m.db.Where("device_name = ? AND deactivated_at IS NULL", device).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) GetFilteredAndSortedPeers(sortKey, sortDirection, search string) []Peer { func (m *PeerManager) GetFilteredAndSortedPeers(device, sortKey, sortDirection, search string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Find(&peers) m.db.Where("device_name = ?", device).Find(&peers)
filteredPeers := make([]Peer, 0, len(peers)) filteredPeers := make([]Peer, 0, len(peers))
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
if search == "" || if search == "" ||
strings.Contains(peers[i].Email, search) || strings.Contains(peers[i].Email, search) ||
@@ -499,12 +506,12 @@ func (u *PeerManager) GetFilteredAndSortedPeers(sortKey, sortDirection, search s
return filteredPeers return filteredPeers
} }
func (u *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer { func (m *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer {
peers := make([]Peer, 0) peers := make([]Peer, 0)
u.db.Where("email = ?", email).Find(&peers) m.db.Where("email = ?", email).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
sort.Slice(peers, func(i, j int) bool { sort.Slice(peers, func(i, j int) bool {
@@ -544,42 +551,42 @@ func (u *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email strin
return peers return peers
} }
func (u *PeerManager) GetDevice() Device { func (m *PeerManager) GetDevice(device string) Device {
devices := make([]Device, 0, 1) dev := Device{}
u.db.Find(&devices)
for i := range devices { m.db.Where("device_name = ?", device).First(&dev)
u.populateDeviceData(&devices[i]) m.populateDeviceData(&dev)
return dev
} }
return devices[0] // use first device for now... more to come? func (m *PeerManager) GetPeerByKey(publicKey string) Peer {
}
func (u *PeerManager) GetPeerByKey(publicKey string) Peer {
peer := Peer{} peer := Peer{}
u.db.Where("public_key = ?", publicKey).FirstOrInit(&peer) m.db.Where("public_key = ?", publicKey).FirstOrInit(&peer)
u.populatePeerData(&peer) m.populatePeerData(&peer)
return peer return peer
} }
func (u *PeerManager) GetPeersByMail(mail string) []Peer { func (m *PeerManager) GetPeersByMail(mail string) []Peer {
var peers []Peer var peers []Peer
u.db.Where("email = ?", mail).Find(&peers) m.db.Where("email = ?", mail).Find(&peers)
for i := range peers { for i := range peers {
u.populatePeerData(&peers[i]) m.populatePeerData(&peers[i])
} }
return peers return peers
} }
func (u *PeerManager) CreatePeer(peer Peer) error { // ---- Database helpers -----
func (m *PeerManager) CreatePeer(peer Peer) error {
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.UpdatedAt = time.Now() peer.UpdatedAt = time.Now()
peer.CreatedAt = time.Now() peer.CreatedAt = time.Now()
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
res := u.db.Create(&peer) res := m.db.Create(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to create peer: %v", res.Error) logrus.Errorf("failed to create peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to create peer") return errors.Wrap(res.Error, "failed to create peer")
@@ -588,12 +595,12 @@ func (u *PeerManager) CreatePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) UpdatePeer(peer Peer) error { func (m *PeerManager) UpdatePeer(peer Peer) error {
peer.UpdatedAt = time.Now() peer.UpdatedAt = time.Now()
peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ") peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
peer.IPsStr = strings.Join(peer.IPs, ", ") peer.IPsStr = strings.Join(peer.IPs, ", ")
res := u.db.Save(&peer) res := m.db.Save(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to update peer: %v", res.Error) logrus.Errorf("failed to update peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to update peer") return errors.Wrap(res.Error, "failed to update peer")
@@ -602,8 +609,8 @@ func (u *PeerManager) UpdatePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) DeletePeer(peer Peer) error { func (m *PeerManager) DeletePeer(peer Peer) error {
res := u.db.Delete(&peer) res := m.db.Delete(&peer)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to delete peer: %v", res.Error) logrus.Errorf("failed to delete peer: %v", res.Error)
return errors.Wrap(res.Error, "failed to delete peer") return errors.Wrap(res.Error, "failed to delete peer")
@@ -612,13 +619,13 @@ func (u *PeerManager) DeletePeer(peer Peer) error {
return nil return nil
} }
func (u *PeerManager) UpdateDevice(device Device) error { func (m *PeerManager) UpdateDevice(device Device) error {
device.UpdatedAt = time.Now() device.UpdatedAt = time.Now()
device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ") device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ")
device.IPsStr = strings.Join(device.IPs, ", ") device.IPsStr = strings.Join(device.IPs, ", ")
device.DNSStr = strings.Join(device.DNS, ", ") device.DNSStr = strings.Join(device.DNS, ", ")
res := u.db.Save(&device) res := m.db.Save(&device)
if res.Error != nil { if res.Error != nil {
logrus.Errorf("failed to update device: %v", res.Error) logrus.Errorf("failed to update device: %v", res.Error)
return errors.Wrap(res.Error, "failed to update device") return errors.Wrap(res.Error, "failed to update device")
@@ -627,9 +634,11 @@ func (u *PeerManager) UpdateDevice(device Device) error {
return nil return nil
} }
func (u *PeerManager) GetAllReservedIps() ([]string, error) { // ---- IP helpers ----
func (m *PeerManager) GetAllReservedIps(device string) ([]string, error) {
reservedIps := make([]string, 0) reservedIps := make([]string, 0)
peers := u.GetAllPeers() peers := m.GetAllPeers(device)
for _, user := range peers { for _, user := range peers {
for _, cidr := range user.IPs { for _, cidr := range user.IPs {
if cidr == "" { if cidr == "" {
@@ -643,8 +652,8 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
} }
} }
device := u.GetDevice() dev := m.GetDevice(device)
for _, cidr := range device.IPs { for _, cidr := range dev.IPs {
if cidr == "" { if cidr == "" {
continue continue
} }
@@ -659,8 +668,8 @@ func (u *PeerManager) GetAllReservedIps() ([]string, error) {
return reservedIps, nil return reservedIps, nil
} }
func (u *PeerManager) IsIPReserved(cidr string) bool { func (m *PeerManager) IsIPReserved(device string, cidr string) bool {
reserved, err := u.GetAllReservedIps() reserved, err := m.GetAllReservedIps(device)
if err != nil { if err != nil {
return true // in case something failed, assume the ip is reserved return true // in case something failed, assume the ip is reserved
} }
@@ -688,10 +697,10 @@ func (u *PeerManager) IsIPReserved(cidr string) bool {
} }
// GetAvailableIp search for an available ip in cidr against a list of reserved ips // GetAvailableIp search for an available ip in cidr against a list of reserved ips
func (u *PeerManager) GetAvailableIp(cidr string) (string, error) { func (m *PeerManager) GetAvailableIp(device string, cidr string) (string, error) {
reserved, err := u.GetAllReservedIps() reserved, err := m.GetAllReservedIps(device)
if err != nil { if err != nil {
return "", errors.WithMessage(err, "failed to get all reserved IP addresses") return "", errors.WithMessagef(err, "failed to get all reserved IP addresses for %s", device)
} }
ip, ipnet, err := net.ParseCIDR(cidr) ip, ipnet, err := net.ParseCIDR(cidr)
if err != nil { if err != nil {

View File

@@ -1,6 +1,9 @@
LISTENING_ADDRESS=:8080 LISTENING_ADDRESS=:8080
WG_DEVICES=wg0
WG_DEFAULT_DEVICE=wg0
WG_CONFIG_PATH=/etc/wireguard
EXTERNAL_URL=https://vpn.company.com EXTERNAL_URL=https://vpn.company.com
WEBSITE_TITLE=WireGuard VPN WEBSITE_TITLE=WireGuard VPN
COMPANY_NAME=Your Company Name COMPANY_NAME=Your Company Name
ADMIN_USER=admin ADMIN_USER=admin@wgportal.local
ADMIN_PASS=supersecret ADMIN_PASS=supersecret