diff --git a/README.md b/README.md index 03522f8..fd44869 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,8 @@ of wireproxy by [@juev](https://github.com/juev). ``` usage: wireproxy [-h|--help] [-c|--config ""] [-s|--silent] - [-d|--daemon] [-v|--version] [-n|--configtest] + [-d|--daemon] [-i|--info ""] [-v|--version] + [-n|--configtest] Userspace wireguard client for proxying @@ -48,9 +49,11 @@ Arguments: -c --config Path of configuration file -s --silent Silent mode -d --daemon Make wireproxy run in background + -i --info Specify the address and port for exposing health status -v --version Print version -n --configtest Configtest mode. Only check the configuration file for validity. + ``` # Build instruction @@ -188,6 +191,64 @@ PublicKey = YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY= AllowedIPs = 10.254.254.100/32 # Note there is no Endpoint defined here. ``` +# Health endpoint +Wireproxy supports exposing a health endpoint for monitoring purposes. +The argument `--info/-i` specifies an address and port (e.g. `localhost:9080`), which exposes a HTTP server that provides health status metric of the server. + +Currently two endpoints are implemented: + +`/metrics`: Exposes information of the wireguard daemon, this provides the same information you would get with `wg show`. [This](https://www.wireguard.com/xplatform/#example-dialog) shows an example of what the response would look like. + +`/readyz`: This responds with a json which shows the last time a pong is received from an IP specified with `CheckAlive`. When `CheckAlive` is set, a ping is sent out to addresses in `CheckAlive` per `CheckAliveInterval` seconds (defaults to 5) via wireguard. If a pong has not been received from one of the addresses within the last `CheckAliveInterval` seconds (+2 seconds for some leeway to account for latency), then it would respond with a 503, otherwise a 200. + +For example: +``` +[Interface] +PrivateKey = censored +Address = 10.2.0.2/32 +DNS = 10.2.0.1 +CheckAlive = 1.1.1.1, 3.3.3.3 +CheckAliveInterval = 3 + +[Peer] +PublicKey = censored +AllowedIPs = 0.0.0.0/0 +Endpoint = 149.34.244.174:51820 + +[Socks5] +BindAddress = 127.0.0.1:25344 +``` +`/readyz` would respond with +``` +< HTTP/1.1 503 Service Unavailable +< Date: Thu, 11 Apr 2024 00:54:59 GMT +< Content-Length: 35 +< Content-Type: text/plain; charset=utf-8 +< +{"1.1.1.1":1712796899,"3.3.3.3":0} +``` + +And for: +``` +[Interface] +PrivateKey = censored +Address = 10.2.0.2/32 +DNS = 10.2.0.1 +CheckAlive = 1.1.1.1 +``` +`/readyz` would respond with +``` +< HTTP/1.1 200 OK +< Date: Thu, 11 Apr 2024 00:56:21 GMT +< Content-Length: 23 +< Content-Type: text/plain; charset=utf-8 +< +{"1.1.1.1":1712796979} +``` + +If nothing is set for `CheckAlive`, an empty JSON object with 200 will be the response. + +The peer which the ICMP ping packet is routed to depends on the `AllowedIPs` set for each peers. # Stargazers over time [![Stargazers over time](https://starchart.cc/octeep/wireproxy.svg)](https://starchart.cc/octeep/wireproxy) diff --git a/cmd/wireproxy/main.go b/cmd/wireproxy/main.go index 2bbe4a2..9b10dfb 100644 --- a/cmd/wireproxy/main.go +++ b/cmd/wireproxy/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net/http" "os" "os/exec" "os/signal" @@ -78,6 +79,7 @@ func main() { config := parser.String("c", "config", &argparse.Options{Help: "Path of configuration file"}) silent := parser.Flag("s", "silent", &argparse.Options{Help: "Silent mode"}) daemon := parser.Flag("d", "daemon", &argparse.Options{Help: "Make wireproxy run in background"}) + info := parser.String("i", "info", &argparse.Options{Help: "Specify the address and port for exposing health status"}) printVerison := parser.Flag("v", "version", &argparse.Options{Help: "Print version"}) configTest := parser.Flag("n", "configtest", &argparse.Options{Help: "Configtest mode. Only check the configuration file for validity."}) @@ -140,13 +142,24 @@ func main() { // no file access is allowed from now on, only networking pledgeOrPanic("stdio inet dns") - tnet, err := wireproxy.StartWireguard(conf.Device, logLevel) + tun, err := wireproxy.StartWireguard(conf.Device, logLevel) if err != nil { log.Fatal(err) } for _, spawner := range conf.Routines { - go spawner.SpawnRoutine(tnet) + go spawner.SpawnRoutine(tun) + } + + tun.StartPingIPs() + + if *info != "" { + go func() { + err := http.ListenAndServe(*info, tun) + if err != nil { + panic(err) + } + }() } <-ctx.Done() diff --git a/config.go b/config.go index 4f363ec..76593cf 100644 --- a/config.go +++ b/config.go @@ -22,12 +22,14 @@ type PeerConfig struct { // DeviceConfig contains the information to initiate a wireguard connection type DeviceConfig struct { - SecretKey string - Endpoint []netip.Addr - Peers []PeerConfig - DNS []netip.Addr - MTU int - ListenPort *int + SecretKey string + Endpoint []netip.Addr + Peers []PeerConfig + DNS []netip.Addr + MTU int + ListenPort *int + CheckAlive []netip.Addr + CheckAliveInterval int } type TCPClientTunnelConfig struct { @@ -237,6 +239,25 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error { device.ListenPort = &value } + checkAlive, err := parseNetIP(section, "CheckAlive") + if err != nil { + return err + } + device.CheckAlive = checkAlive + + device.CheckAliveInterval = 5 + if sectionKey, err := section.GetKey("CheckAliveInterval"); err == nil { + value, err := sectionKey.Int() + if err != nil { + return err + } + if len(checkAlive) == 0 { + return errors.New("CheckAliveInterval is only valid when CheckAlive is set") + } + + device.CheckAliveInterval = value + } + return nil } diff --git a/routine.go b/routine.go index 13fa944..465e6b1 100644 --- a/routine.go +++ b/routine.go @@ -1,15 +1,27 @@ package wireproxy import ( + "bytes" "context" + srand "crypto/rand" "crypto/subtle" + "encoding/binary" + "encoding/json" "errors" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/device" "io" "log" "math/rand" "net" + "net/http" "os" + "path" "strconv" + "strings" + "time" "github.com/sourcegraph/conc" "github.com/things-go/go-socks5" @@ -32,7 +44,11 @@ type CredentialValidator struct { // VirtualTun stores a reference to netstack network and DNS configuration type VirtualTun struct { Tnet *netstack.Net + Dev *device.Device SystemDNS bool + Conf *DeviceConfig + // PingRecord stores the last time an IP was pinged + PingRecord map[string]uint64 } // RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed @@ -148,16 +164,16 @@ func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) { // SpawnRoutine spawns a http server. func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { - http := &HTTPServer{ + server := &HTTPServer{ config: config, dial: vt.Tnet.Dial, auth: CredentialValidator{config.Username, config.Password}, } if config.Username != "" || config.Password != "" { - http.authRequired = true + server.authRequired = true } - if err := http.ListenAndServe("tcp", config.BindAddress); err != nil { + if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { log.Fatal(err) } } @@ -330,3 +346,151 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { go tcpServerForward(vt, raddr, conn) } } + +func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.Printf("Health metric request: %s\n", r.URL.Path) + switch path.Clean(r.URL.Path) { + case "/readyz": + body, err := json.Marshal(d.PingRecord) + if err != nil { + errorLogger.Printf("Failed to get device metrics: %s\n", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + status := http.StatusOK + for _, record := range d.PingRecord { + lastPong := time.Unix(int64(record), 0) + // +2 seconds to account for the time it takes to ping the IP + if time.Since(lastPong) > time.Duration(d.Conf.CheckAliveInterval+2)*time.Second { + status = http.StatusServiceUnavailable + break + } + } + + w.WriteHeader(status) + _, _ = w.Write(body) + _, _ = w.Write([]byte("\n")) + case "/metrics": + get, err := d.Dev.IpcGet() + if err != nil { + errorLogger.Printf("Failed to get device metrics: %s\n", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + var buf bytes.Buffer + for _, peer := range strings.Split(get, "\n") { + pair := strings.SplitN(peer, "=", 2) + if len(pair) != 2 { + buf.WriteString(peer) + continue + } + if pair[0] == "private_key" || pair[0] == "preshared_key" { + pair[1] = "REDACTED" + } + buf.WriteString(pair[0]) + buf.WriteString("=") + buf.WriteString(pair[1]) + buf.WriteString("\n") + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write(buf.Bytes()) + default: + w.WriteHeader(http.StatusNotFound) + } +} + +func (d VirtualTun) pingIPs() { + for _, addr := range d.Conf.CheckAlive { + socket, err := d.Tnet.Dial("ping", addr.String()) + if err != nil { + errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + continue + } + + data := make([]byte, 16) + _, _ = srand.Read(data) + + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: data, + } + + var icmpBytes []byte + if addr.Is4() { + icmpBytes, _ = (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + } else if addr.Is6() { + icmpBytes, _ = (&icmp.Message{Type: ipv6.ICMPTypeEchoRequest, Code: 0, Body: &requestPing}).Marshal(nil) + } else { + errorLogger.Printf("Failed to ping %s: invalid address: %s\n", addr, addr.String()) + continue + } + + _ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second)) + _, err = socket.Write(icmpBytes) + if err != nil { + errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + continue + } + + addr := addr + go func() { + n, err := socket.Read(icmpBytes[:]) + if err != nil { + errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error()) + return + } + + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error()) + return + } + + if addr.Is4() { + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + return + } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + return + } + } + + if addr.Is6() { + replyPing, ok := replyPacket.Body.(*icmp.RawBody) + if !ok { + errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + return + } + + seq := binary.BigEndian.Uint16(replyPing.Data[2:4]) + pongBody := replyPing.Data[4:] + if !bytes.Equal(pongBody, requestPing.Data) || int(seq) != requestPing.Seq { + errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + return + } + } + + d.PingRecord[addr.String()] = uint64(time.Now().Unix()) + + defer socket.Close() + }() + } +} + +func (d VirtualTun) StartPingIPs() { + for _, addr := range d.Conf.CheckAlive { + d.PingRecord[addr.String()] = 0 + } + + go func() { + for { + d.pingIPs() + time.Sleep(time.Duration(d.Conf.CheckAliveInterval) * time.Second) + } + }() +} diff --git a/wireguard.go b/wireguard.go index b98bc35..31057ed 100644 --- a/wireguard.go +++ b/wireguard.go @@ -81,7 +81,10 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { } return &VirtualTun{ - Tnet: tnet, - SystemDNS: len(setting.dns) == 0, + Tnet: tnet, + Dev: dev, + Conf: conf, + SystemDNS: len(setting.dns) == 0, + PingRecord: make(map[string]uint64), }, nil }