From efc7e62704e18d98c4ebe8e3bed58b864969d8d7 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Thu, 11 Apr 2024 01:46:08 +0100 Subject: [PATCH] implement ICMP ping --- cmd/wireproxy/main.go | 2 + config.go | 33 ++++++++++++--- routine.go | 96 ++++++++++++++++++++++++++++++++++++++++++- wireguard.go | 8 ++-- 4 files changed, 129 insertions(+), 10 deletions(-) diff --git a/cmd/wireproxy/main.go b/cmd/wireproxy/main.go index 9c8a7ca..9b10dfb 100644 --- a/cmd/wireproxy/main.go +++ b/cmd/wireproxy/main.go @@ -151,6 +151,8 @@ func main() { go spawner.SpawnRoutine(tun) } + tun.StartPingIPs() + if *info != "" { go func() { err := http.ListenAndServe(*info, tun) diff --git a/config.go b/config.go index 4f363ec..76593cf 100644 --- a/config.go +++ b/config.go @@ -22,12 +22,14 @@ type PeerConfig struct { // DeviceConfig contains the information to initiate a wireguard connection type DeviceConfig struct { - SecretKey string - Endpoint []netip.Addr - Peers []PeerConfig - DNS []netip.Addr - MTU int - ListenPort *int + SecretKey string + Endpoint []netip.Addr + Peers []PeerConfig + DNS []netip.Addr + MTU int + ListenPort *int + CheckAlive []netip.Addr + CheckAliveInterval int } type TCPClientTunnelConfig struct { @@ -237,6 +239,25 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error { device.ListenPort = &value } + checkAlive, err := parseNetIP(section, "CheckAlive") + if err != nil { + return err + } + device.CheckAlive = checkAlive + + device.CheckAliveInterval = 5 + if sectionKey, err := section.GetKey("CheckAliveInterval"); err == nil { + value, err := sectionKey.Int() + if err != nil { + return err + } + if len(checkAlive) == 0 { + return errors.New("CheckAliveInterval is only valid when CheckAlive is set") + } + + device.CheckAliveInterval = value + } + return nil } diff --git a/routine.go b/routine.go index 6f7429f..68e0e20 100644 --- a/routine.go +++ b/routine.go @@ -4,7 +4,10 @@ import ( "bytes" "context" "crypto/subtle" + "encoding/json" "errors" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/device" "io" "log" @@ -15,6 +18,7 @@ import ( "path" "strconv" "strings" + "time" "github.com/sourcegraph/conc" "github.com/things-go/go-socks5" @@ -39,6 +43,9 @@ type VirtualTun struct { Tnet *netstack.Net Dev *device.Device SystemDNS bool + Conf *DeviceConfig + // PingRecord stores the last time an IP was pinged + PingRecord map[string]uint64 } // RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed @@ -341,7 +348,26 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("Health metric request: %s\n", r.URL.Path) switch path.Clean(r.URL.Path) { case "/readyz": - w.WriteHeader(http.StatusOK) + body, err := json.Marshal(d.PingRecord) + if err != nil { + errorLogger.Printf("Failed to get device metrics: %s\n", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + status := http.StatusOK + for _, record := range d.PingRecord { + lastPong := time.Unix(int64(record), 0) + // +2 seconds to account for the time it takes to ping the IP + if time.Since(lastPong) > time.Duration(d.Conf.CheckAliveInterval+2)*time.Second { + status = http.StatusServiceUnavailable + break + } + } + + w.WriteHeader(status) + _, _ = w.Write(body) + _, _ = w.Write([]byte("\n")) case "/metrics": get, err := d.Dev.IpcGet() if err != nil { @@ -371,3 +397,71 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) } } + +func (d VirtualTun) pingIPs() { + for _, addr := range d.Conf.CheckAlive { + socket, err := d.Tnet.Dial("ping", addr.String()) + if err != nil { + errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + continue + } + + data := make([]byte, 16) + rand.Read(data) + + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: data, + } + + icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + _ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second)) + _, err = socket.Write(icmpBytes) + if err != nil { + errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + continue + } + + addr := addr + go func() { + n, err := socket.Read(icmpBytes[:]) + if err != nil { + errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error()) + return + } + + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error()) + return + } + + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %v\n", addr, replyPacket) + return + } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + return + } + + d.PingRecord[addr.String()] = uint64(time.Now().Unix()) + + defer socket.Close() + }() + } +} + +func (d VirtualTun) StartPingIPs() { + for _, addr := range d.Conf.CheckAlive { + d.PingRecord[addr.String()] = 0 + } + + go func() { + for { + d.pingIPs() + time.Sleep(time.Duration(d.Conf.CheckAliveInterval) * time.Second) + } + }() +} diff --git a/wireguard.go b/wireguard.go index bd0b9c1..31057ed 100644 --- a/wireguard.go +++ b/wireguard.go @@ -81,8 +81,10 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { } return &VirtualTun{ - Tnet: tnet, - Dev: dev, - SystemDNS: len(setting.dns) == 0, + Tnet: tnet, + Dev: dev, + Conf: conf, + SystemDNS: len(setting.dns) == 0, + PingRecord: make(map[string]uint64), }, nil }