mirror of
https://github.com/whyvl/wireproxy.git
synced 2025-04-29 19:01:42 +02:00
613 lines
15 KiB
Go
613 lines
15 KiB
Go
package wireproxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
srand "crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"golang.org/x/net/icmp"
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
"golang.zx2c4.com/wireguard/device"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sourcegraph/conc"
|
|
"github.com/things-go/go-socks5"
|
|
"github.com/things-go/go-socks5/bufferpool"
|
|
|
|
"net/netip"
|
|
|
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
)
|
|
|
|
// errorLogger writes error messages to standard error, prefixed with
// "ERROR: " and stamped with the standard date/time flags.
var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags)
|
|
|
|
// ProxyStats holds the connection counters of a single proxy. The exported
// fields are serialized to JSON; mu serializes updates made through
// IncConnection and DecConnection.
type ProxyStats struct {
	// ProxyType names the kind of proxy (for example "socks5" or "http").
	ProxyType string `json:"type"`
	// BindAddress is the address the proxy was configured to listen on.
	BindAddress string `json:"bind_address"`
	// ActiveConnections is the number of connections currently counted as open.
	ActiveConnections int `json:"active_connections"`
	// LastConnectionTime is the Unix time (seconds) of the latest IncConnection.
	LastConnectionTime int64 `json:"last_connection_time"`
	// TotalConnections is the number of connections counted so far.
	TotalConnections int `json:"total_connections"`

	mu sync.Mutex
}

// IncConnection records a newly opened connection: it bumps both the active
// and the total counters and stores the current time as the last connection
// time.
func (ps *ProxyStats) IncConnection() {
	ps.mu.Lock()
	ps.ActiveConnections, ps.TotalConnections = ps.ActiveConnections+1, ps.TotalConnections+1
	ps.LastConnectionTime = time.Now().Unix()
	ps.mu.Unlock()
}

// DecConnection records a closed connection. The active counter never drops
// below zero.
func (ps *ProxyStats) DecConnection() {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	if ps.ActiveConnections <= 0 {
		return
	}
	ps.ActiveConnections--
}
|
|
|
|
type StatsListener struct {
|
|
net.Listener
|
|
stats *ProxyStats
|
|
}
|
|
|
|
func (sl *StatsListener) Accept() (net.Conn, error) {
|
|
c, err := sl.Listener.Accept()
|
|
if err == nil {
|
|
sl.stats.IncConnection()
|
|
c = &StatsConn{
|
|
Conn: c,
|
|
stats: sl.stats,
|
|
}
|
|
}
|
|
return c, err
|
|
}
|
|
|
|
type StatsConn struct {
|
|
net.Conn
|
|
stats *ProxyStats
|
|
}
|
|
|
|
func (sc *StatsConn) Close() error {
|
|
sc.stats.DecConnection()
|
|
return sc.Conn.Close()
|
|
}
|
|
|
|
// CredentialValidator holds the username/password pair that incoming
// credentials are checked against.
type CredentialValidator struct {
	username, password string
}
|
|
|
|
// VirtualTun stores a reference to netstack network and DNS configuration
|
|
type VirtualTun struct {
|
|
Tnet *netstack.Net
|
|
Dev *device.Device
|
|
SystemDNS bool
|
|
Conf *DeviceConfig
|
|
// PingRecord stores the last time an IP was pinged
|
|
PingRecord map[string]uint64
|
|
mu sync.Mutex
|
|
ProxyList []*ProxyStats
|
|
}
|
|
|
|
// RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed
|
|
type RoutineSpawner interface {
|
|
SpawnRoutine(vt *VirtualTun)
|
|
}
|
|
|
|
// addressPort is an unresolved endpoint: a host (name or IP literal) and a
// numeric port.
type addressPort struct {
	// address is the host part of the endpoint, not yet resolved.
	address string
	// port is the port part of the endpoint.
	port uint16
}
|
|
|
|
// LookupAddr lookups a hostname.
|
|
// DNS traffic may or may not be routed depending on VirtualTun's setting
|
|
func (d *VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) {
|
|
if d.SystemDNS {
|
|
return net.DefaultResolver.LookupHost(ctx, name)
|
|
}
|
|
return d.Tnet.LookupContextHost(ctx, name)
|
|
}
|
|
|
|
// ResolveAddrWithContext resolves a hostname and returns an AddrPort.
|
|
// DNS traffic may or may not be routed depending on VirtualTun's setting
|
|
func (d *VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) {
|
|
addrs, err := d.LookupAddr(ctx, name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
size := len(addrs)
|
|
if size == 0 {
|
|
return nil, errors.New("no address found for: " + name)
|
|
}
|
|
|
|
rand.Shuffle(size, func(i, j int) {
|
|
addrs[i], addrs[j] = addrs[j], addrs[i]
|
|
})
|
|
|
|
var addr netip.Addr
|
|
for _, saddr := range addrs {
|
|
addr, err = netip.ParseAddr(saddr)
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &addr, nil
|
|
}
|
|
|
|
// Resolve resolves a hostname and returns an IP.
|
|
// DNS traffic may or may not be routed depending on VirtualTun's setting
|
|
func (d *VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
|
addr, err := d.ResolveAddrWithContext(ctx, name)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return ctx, addr.AsSlice(), nil
|
|
}
|
|
|
|
func parseAddressPort(endpoint string) (*addressPort, error) {
|
|
name, sport, err := net.SplitHostPort(endpoint)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
port, err := strconv.Atoi(sport)
|
|
if err != nil || port < 0 || port > 65535 {
|
|
return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")}
|
|
}
|
|
|
|
return &addressPort{address: name, port: uint16(port)}, nil
|
|
}
|
|
|
|
func (d *VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) {
|
|
addr, err := d.ResolveAddrWithContext(context.Background(), endpoint.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
addrPort := netip.AddrPortFrom(*addr, endpoint.port)
|
|
return &addrPort, nil
|
|
}
|
|
|
|
// SpawnRoutine spawns a socks5 server.
|
|
func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) {
|
|
stats := &ProxyStats{
|
|
ProxyType: "socks5",
|
|
BindAddress: config.BindAddress,
|
|
}
|
|
vt.RegisterProxyStats(stats)
|
|
|
|
var authMethods []socks5.Authenticator
|
|
if username := config.Username; username != "" {
|
|
authMethods = append(authMethods, socks5.UserPassAuthenticator{
|
|
Credentials: socks5.StaticCredentials{username: config.Password},
|
|
})
|
|
} else {
|
|
authMethods = append(authMethods, socks5.NoAuthAuthenticator{})
|
|
}
|
|
|
|
options := []socks5.Option{
|
|
socks5.WithDial(vt.Tnet.DialContext),
|
|
socks5.WithResolver(vt),
|
|
socks5.WithAuthMethods(authMethods),
|
|
socks5.WithBufferPool(bufferpool.NewPool(256 * 1024)),
|
|
}
|
|
|
|
server := socks5.NewServer(options...)
|
|
|
|
ln, err := net.Listen("tcp", config.BindAddress)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
ln = &StatsListener{Listener: ln, stats: stats}
|
|
|
|
go func() {
|
|
if err := server.Serve(ln); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// SpawnRoutine spawns a http server.
|
|
func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) {
|
|
stats := &ProxyStats{
|
|
ProxyType: "http",
|
|
BindAddress: config.BindAddress,
|
|
}
|
|
vt.RegisterProxyStats(stats)
|
|
|
|
server := &HTTPServer{
|
|
config: config,
|
|
dial: vt.Tnet.Dial,
|
|
auth: CredentialValidator{config.Username, config.Password},
|
|
}
|
|
if config.Username != "" || config.Password != "" {
|
|
server.authRequired = true
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", config.BindAddress)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
ln = &StatsListener{Listener: ln, stats: stats}
|
|
|
|
go func() {
|
|
if err := server.Serve(ln); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Valid checks the authentication data in CredentialValidator and compare them
|
|
// to username and password in constant time.
|
|
func (c CredentialValidator) Valid(username, password string) bool {
|
|
u := subtle.ConstantTimeCompare([]byte(c.username), []byte(username))
|
|
p := subtle.ConstantTimeCompare([]byte(c.password), []byte(password))
|
|
return u&p == 1
|
|
}
|
|
|
|
// connForward copy data from `from` to `to`
|
|
func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) {
|
|
_, err := io.Copy(to, from)
|
|
if err != nil {
|
|
errorLogger.Printf("Cannot forward traffic: %s\n", err.Error())
|
|
}
|
|
}
|
|
|
|
// tcpClientForward starts a new connection via wireguard and forward traffic from `conn`
|
|
func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
|
|
target, err := vt.resolveToAddrPort(raddr)
|
|
if err != nil {
|
|
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error())
|
|
return
|
|
}
|
|
|
|
tcpAddr := TCPAddrFromAddrPort(*target)
|
|
|
|
sconn, err := vt.Tnet.DialTCP(tcpAddr)
|
|
if err != nil {
|
|
errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error())
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
wg := conc.NewWaitGroup()
|
|
wg.Go(func() {
|
|
connForward(sconn, conn)
|
|
})
|
|
wg.Go(func() {
|
|
connForward(conn, sconn)
|
|
})
|
|
wg.Wait()
|
|
_ = sconn.Close()
|
|
_ = conn.Close()
|
|
sconn = nil
|
|
conn = nil
|
|
}()
|
|
}
|
|
|
|
// STDIOTcpForward starts a new connection via wireguard and forward traffic from `conn`
|
|
func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) {
|
|
target, err := vt.resolveToAddrPort(raddr)
|
|
if err != nil {
|
|
errorLogger.Printf("Name resolution error for %s: %s\n", raddr.address, err.Error())
|
|
return
|
|
}
|
|
|
|
// os.Stdout has previously been remapped to stderr, se we can't use it
|
|
stdout, err := os.OpenFile("/dev/stdout", os.O_WRONLY, 0)
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to open /dev/stdout: %s\n", err.Error())
|
|
return
|
|
}
|
|
|
|
tcpAddr := TCPAddrFromAddrPort(*target)
|
|
sconn, err := vt.Tnet.DialTCP(tcpAddr)
|
|
if err != nil {
|
|
errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error())
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
wg := conc.NewWaitGroup()
|
|
wg.Go(func() {
|
|
connForward(os.Stdin, sconn)
|
|
})
|
|
wg.Go(func() {
|
|
connForward(sconn, stdout)
|
|
})
|
|
wg.Wait()
|
|
_ = sconn.Close()
|
|
sconn = nil
|
|
}()
|
|
}
|
|
|
|
// SpawnRoutine spawns a local TCP server which acts as a proxy to the specified target
|
|
func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) {
|
|
raddr, err := parseAddressPort(conf.Target)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
server, err := net.ListenTCP("tcp", conf.BindAddress)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for {
|
|
conn, err := server.Accept()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
go tcpClientForward(vt, raddr, conn)
|
|
}
|
|
}
|
|
|
|
// SpawnRoutine connects to the specified target and plumbs it to STDIN / STDOUT
|
|
func (conf *STDIOTunnelConfig) SpawnRoutine(vt *VirtualTun) {
|
|
raddr, err := parseAddressPort(conf.Target)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
go STDIOTcpForward(vt, raddr)
|
|
}
|
|
|
|
// tcpServerForward starts a new connection locally and forward traffic from `conn`
|
|
func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
|
|
target, err := vt.resolveToAddrPort(raddr)
|
|
if err != nil {
|
|
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error())
|
|
return
|
|
}
|
|
|
|
tcpAddr := TCPAddrFromAddrPort(*target)
|
|
|
|
sconn, err := net.DialTCP("tcp", nil, tcpAddr)
|
|
if err != nil {
|
|
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
gr := conc.NewWaitGroup()
|
|
gr.Go(func() {
|
|
connForward(sconn, conn)
|
|
})
|
|
gr.Go(func() {
|
|
connForward(conn, sconn)
|
|
})
|
|
gr.Wait()
|
|
_ = sconn.Close()
|
|
_ = conn.Close()
|
|
sconn = nil
|
|
conn = nil
|
|
}()
|
|
}
|
|
|
|
// SpawnRoutine spawns a TCP server on wireguard which acts as a proxy to the specified target
|
|
func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
|
|
raddr, err := parseAddressPort(conf.Target)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
addr := &net.TCPAddr{Port: conf.ListenPort}
|
|
server, err := vt.Tnet.ListenTCP(addr)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for {
|
|
conn, err := server.Accept()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
go tcpServerForward(vt, raddr, conn)
|
|
}
|
|
}
|
|
|
|
// ServeHTTP is used for health/metrics requests.
|
|
func (d *VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
log.Printf("Health metric request: %s\n", r.URL.Path)
|
|
switch path.Clean(r.URL.Path) {
|
|
case "/readyz":
|
|
body, err := json.Marshal(d.PingRecord)
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
status := http.StatusOK
|
|
for _, record := range d.PingRecord {
|
|
lastPong := time.Unix(int64(record), 0)
|
|
// +2 seconds to account for the time it takes to ping the IP
|
|
if time.Since(lastPong) > time.Duration(d.Conf.CheckAliveInterval+2)*time.Second {
|
|
status = http.StatusServiceUnavailable
|
|
break
|
|
}
|
|
}
|
|
|
|
w.WriteHeader(status)
|
|
_, _ = w.Write(body)
|
|
_, _ = w.Write([]byte("\n"))
|
|
case "/metrics":
|
|
get, err := d.Dev.IpcGet()
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
var buf bytes.Buffer
|
|
for _, peer := range strings.Split(get, "\n") {
|
|
pair := strings.SplitN(peer, "=", 2)
|
|
if len(pair) != 2 {
|
|
buf.WriteString(peer)
|
|
continue
|
|
}
|
|
if pair[0] == "private_key" || pair[0] == "preshared_key" {
|
|
pair[1] = "REDACTED"
|
|
}
|
|
buf.WriteString(pair[0])
|
|
buf.WriteString("=")
|
|
buf.WriteString(pair[1])
|
|
buf.WriteString("\n")
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(buf.Bytes())
|
|
|
|
case "/stats":
|
|
// Return statistics about all running proxies
|
|
d.mu.Lock()
|
|
snapshot := make([]ProxyStats, len(d.ProxyList))
|
|
for i, ps := range d.ProxyList {
|
|
ps.mu.Lock()
|
|
snapshot[i] = ProxyStats{
|
|
ProxyType: ps.ProxyType,
|
|
BindAddress: ps.BindAddress,
|
|
ActiveConnections: ps.ActiveConnections,
|
|
LastConnectionTime: ps.LastConnectionTime,
|
|
TotalConnections: ps.TotalConnections,
|
|
}
|
|
ps.mu.Unlock()
|
|
}
|
|
d.mu.Unlock()
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(snapshot); err != nil {
|
|
errorLogger.Printf("Failed to encode /stats: %s", err)
|
|
}
|
|
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
// pingIPs pings the IP addresses configured in CheckAlive
|
|
func (d *VirtualTun) pingIPs() {
|
|
for _, addr := range d.Conf.CheckAlive {
|
|
socket, err := d.Tnet.Dial("ping", addr.String())
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
|
|
continue
|
|
}
|
|
|
|
data := make([]byte, 16)
|
|
_, _ = srand.Read(data)
|
|
|
|
requestPing := icmp.Echo{
|
|
Seq: rand.Intn(1 << 16),
|
|
Data: data,
|
|
}
|
|
|
|
var icmpBytes []byte
|
|
if addr.Is4() {
|
|
icmpBytes, _ = (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
|
} else if addr.Is6() {
|
|
icmpBytes, _ = (&icmp.Message{Type: ipv6.ICMPTypeEchoRequest, Code: 0, Body: &requestPing}).Marshal(nil)
|
|
} else {
|
|
errorLogger.Printf("Failed to ping %s: invalid address: %s\n", addr, addr.String())
|
|
continue
|
|
}
|
|
|
|
_ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second))
|
|
_, err = socket.Write(icmpBytes)
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
|
|
continue
|
|
}
|
|
|
|
addr := addr
|
|
go func() {
|
|
n, err := socket.Read(icmpBytes[:])
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error())
|
|
return
|
|
}
|
|
|
|
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
|
if err != nil {
|
|
errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error())
|
|
return
|
|
}
|
|
|
|
if addr.Is4() {
|
|
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
|
if !ok {
|
|
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
|
|
return
|
|
}
|
|
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
|
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
|
|
return
|
|
}
|
|
}
|
|
|
|
if addr.Is6() {
|
|
replyPing, ok := replyPacket.Body.(*icmp.RawBody)
|
|
if !ok {
|
|
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
|
|
return
|
|
}
|
|
|
|
seq := binary.BigEndian.Uint16(replyPing.Data[2:4])
|
|
pongBody := replyPing.Data[4:]
|
|
if !bytes.Equal(pongBody, requestPing.Data) || int(seq) != requestPing.Seq {
|
|
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
|
|
return
|
|
}
|
|
}
|
|
|
|
d.PingRecord[addr.String()] = uint64(time.Now().Unix())
|
|
|
|
defer socket.Close()
|
|
}()
|
|
}
|
|
}
|
|
|
|
// StartPingIPs starts a goroutine that periodically pings the IP addresses in CheckAlive
|
|
func (d *VirtualTun) StartPingIPs() {
|
|
for _, addr := range d.Conf.CheckAlive {
|
|
d.PingRecord[addr.String()] = 0
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
d.pingIPs()
|
|
time.Sleep(time.Duration(d.Conf.CheckAliveInterval) * time.Second)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// RegisterProxyStats is used to store the newly created proxy stats object
|
|
func (d *VirtualTun) RegisterProxyStats(ps *ProxyStats) {
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
d.ProxyList = append(d.ProxyList, ps)
|
|
}
|