implement ICMP ping

This commit is contained in:
pufferffish 2024-04-11 01:46:08 +01:00
parent 4cf68c94dd
commit efc7e62704
4 changed files with 129 additions and 10 deletions

View file

@ -4,7 +4,10 @@ import (
"bytes"
"context"
"crypto/subtle"
"encoding/json"
"errors"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/device"
"io"
"log"
@ -15,6 +18,7 @@ import (
"path"
"strconv"
"strings"
"time"
"github.com/sourcegraph/conc"
"github.com/things-go/go-socks5"
@ -39,6 +43,9 @@ type VirtualTun struct {
Tnet *netstack.Net
Dev *device.Device
SystemDNS bool
Conf *DeviceConfig
// PingRecord stores the last time an IP was pinged
PingRecord map[string]uint64
}
// RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed
@ -341,7 +348,26 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Health metric request: %s\n", r.URL.Path)
switch path.Clean(r.URL.Path) {
case "/readyz":
w.WriteHeader(http.StatusOK)
body, err := json.Marshal(d.PingRecord)
if err != nil {
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
status := http.StatusOK
for _, record := range d.PingRecord {
lastPong := time.Unix(int64(record), 0)
// +2 seconds to account for the time it takes to ping the IP
if time.Since(lastPong) > time.Duration(d.Conf.CheckAliveInterval+2)*time.Second {
status = http.StatusServiceUnavailable
break
}
}
w.WriteHeader(status)
_, _ = w.Write(body)
_, _ = w.Write([]byte("\n"))
case "/metrics":
get, err := d.Dev.IpcGet()
if err != nil {
@ -371,3 +397,71 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}
}
func (d VirtualTun) pingIPs() {
for _, addr := range d.Conf.CheckAlive {
socket, err := d.Tnet.Dial("ping", addr.String())
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
continue
}
data := make([]byte, 16)
rand.Read(data)
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: data,
}
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
_ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second))
_, err = socket.Write(icmpBytes)
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
continue
}
addr := addr
go func() {
n, err := socket.Read(icmpBytes[:])
if err != nil {
errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error())
return
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error())
return
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %v\n", addr, replyPacket)
return
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
return
}
d.PingRecord[addr.String()] = uint64(time.Now().Unix())
defer socket.Close()
}()
}
}
func (d VirtualTun) StartPingIPs() {
for _, addr := range d.Conf.CheckAlive {
d.PingRecord[addr.String()] = 0
}
go func() {
for {
d.pingIPs()
time.Sleep(time.Duration(d.Conf.CheckAliveInterval) * time.Second)
}
}()
}