implement ICMP ping

This commit is contained in:
pufferffish 2024-04-11 01:46:08 +01:00
parent 4cf68c94dd
commit efc7e62704
4 changed files with 129 additions and 10 deletions

View file

@ -151,6 +151,8 @@ func main() {
go spawner.SpawnRoutine(tun) go spawner.SpawnRoutine(tun)
} }
tun.StartPingIPs()
if *info != "" { if *info != "" {
go func() { go func() {
err := http.ListenAndServe(*info, tun) err := http.ListenAndServe(*info, tun)

View file

@ -22,12 +22,14 @@ type PeerConfig struct {
// DeviceConfig contains the information to initiate a wireguard connection // DeviceConfig contains the information to initiate a wireguard connection
type DeviceConfig struct { type DeviceConfig struct {
SecretKey string SecretKey string
Endpoint []netip.Addr Endpoint []netip.Addr
Peers []PeerConfig Peers []PeerConfig
DNS []netip.Addr DNS []netip.Addr
MTU int MTU int
ListenPort *int ListenPort *int
CheckAlive []netip.Addr
CheckAliveInterval int
} }
type TCPClientTunnelConfig struct { type TCPClientTunnelConfig struct {
@ -237,6 +239,25 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error {
device.ListenPort = &value device.ListenPort = &value
} }
checkAlive, err := parseNetIP(section, "CheckAlive")
if err != nil {
return err
}
device.CheckAlive = checkAlive
device.CheckAliveInterval = 5
if sectionKey, err := section.GetKey("CheckAliveInterval"); err == nil {
value, err := sectionKey.Int()
if err != nil {
return err
}
if len(checkAlive) == 0 {
return errors.New("CheckAliveInterval is only valid when CheckAlive is set")
}
device.CheckAliveInterval = value
}
return nil return nil
} }

View file

@ -4,7 +4,10 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/subtle" "crypto/subtle"
"encoding/json"
"errors" "errors"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"io" "io"
"log" "log"
@ -15,6 +18,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/sourcegraph/conc" "github.com/sourcegraph/conc"
"github.com/things-go/go-socks5" "github.com/things-go/go-socks5"
@ -39,6 +43,9 @@ type VirtualTun struct {
Tnet *netstack.Net Tnet *netstack.Net
Dev *device.Device Dev *device.Device
SystemDNS bool SystemDNS bool
Conf *DeviceConfig
// PingRecord stores the last time an IP was pinged
PingRecord map[string]uint64
} }
// RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed // RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed
@ -341,7 +348,26 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Health metric request: %s\n", r.URL.Path) log.Printf("Health metric request: %s\n", r.URL.Path)
switch path.Clean(r.URL.Path) { switch path.Clean(r.URL.Path) {
case "/readyz": case "/readyz":
w.WriteHeader(http.StatusOK) body, err := json.Marshal(d.PingRecord)
if err != nil {
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
status := http.StatusOK
for _, record := range d.PingRecord {
lastPong := time.Unix(int64(record), 0)
// +2 seconds to account for the time it takes to ping the IP
if time.Since(lastPong) > time.Duration(d.Conf.CheckAliveInterval+2)*time.Second {
status = http.StatusServiceUnavailable
break
}
}
w.WriteHeader(status)
_, _ = w.Write(body)
_, _ = w.Write([]byte("\n"))
case "/metrics": case "/metrics":
get, err := d.Dev.IpcGet() get, err := d.Dev.IpcGet()
if err != nil { if err != nil {
@ -371,3 +397,71 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
} }
} }
func (d VirtualTun) pingIPs() {
for _, addr := range d.Conf.CheckAlive {
socket, err := d.Tnet.Dial("ping", addr.String())
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
continue
}
data := make([]byte, 16)
rand.Read(data)
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: data,
}
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
_ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second))
_, err = socket.Write(icmpBytes)
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
continue
}
addr := addr
go func() {
n, err := socket.Read(icmpBytes[:])
if err != nil {
errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error())
return
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error())
return
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %v\n", addr, replyPacket)
return
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
return
}
d.PingRecord[addr.String()] = uint64(time.Now().Unix())
defer socket.Close()
}()
}
}
func (d VirtualTun) StartPingIPs() {
for _, addr := range d.Conf.CheckAlive {
d.PingRecord[addr.String()] = 0
}
go func() {
for {
d.pingIPs()
time.Sleep(time.Duration(d.Conf.CheckAliveInterval) * time.Second)
}
}()
}

View file

@ -81,8 +81,10 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) {
} }
return &VirtualTun{ return &VirtualTun{
Tnet: tnet, Tnet: tnet,
Dev: dev, Dev: dev,
SystemDNS: len(setting.dns) == 0, Conf: conf,
SystemDNS: len(setting.dns) == 0,
PingRecord: make(map[string]uint64),
}, nil }, nil
} }