diff --git a/routine.go b/routine.go index f10eee1..465e6b1 100644 --- a/routine.go +++ b/routine.go @@ -5,10 +5,12 @@ import ( "context" srand "crypto/rand" "crypto/subtle" + "encoding/binary" "encoding/json" "errors" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/device" "io" "log" @@ -415,7 +417,16 @@ func (d VirtualTun) pingIPs() { Data: data, } - icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + var icmpBytes []byte + if addr.Is4() { + icmpBytes, _ = (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + } else if addr.Is6() { + icmpBytes, _ = (&icmp.Message{Type: ipv6.ICMPTypeEchoRequest, Code: 0, Body: &requestPing}).Marshal(nil) + } else { + errorLogger.Printf("Failed to ping %s: invalid address: %s\n", addr, addr.String()) + continue + } + _ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second)) _, err = socket.Write(icmpBytes) if err != nil { @@ -437,14 +448,31 @@ func (d VirtualTun) pingIPs() { return } - replyPing, ok := replyPacket.Body.(*icmp.Echo) - if !ok { - errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %v\n", addr, replyPacket) - return + if addr.Is4() { + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + return + } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + return + } } - if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) - return + + if addr.Is6() { + replyPing, ok := replyPacket.Body.(*icmp.RawBody) + if !ok { + errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + return + } + + seq := binary.BigEndian.Uint16(replyPing.Data[2:4]) + pongBody := replyPing.Data[4:] + if !bytes.Equal(pongBody, requestPing.Data) || int(seq) != requestPing.Seq { + errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + return + } } d.PingRecord[addr.String()] = uint64(time.Now().Unix())