diff --git a/README.md b/README.md index 12583f6..9a15f13 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,9 @@ DNS = 10.200.200.1 [Peer] PublicKey = QP+A67Z2UBrMgvNIdHv8gPel5URWNLS4B3ZQ2hQIZlg= -# PresharedKey = UItQuvLsyh50ucXHfjF0bbR4IIpVBd74lwKc8uIPXXs= (optinal) +# PresharedKey = UItQuvLsyh50ucXHfjF0bbR4IIpVBd74lwKc8uIPXXs= (optional) Endpoint = my.ddns.example.com:51820 -# PersistentKeepalive = 25 (optinal) +# PersistentKeepalive = 25 (optional) # TCPClientTunnel is a tunnel listening on your machine, # and it forwards any TCP traffic received to the specified target via wireguard. diff --git a/net.go b/net.go new file mode 100644 index 0000000..946afea --- /dev/null +++ b/net.go @@ -0,0 +1,24 @@ +// will delete when upgrading to go 1.18 + +package wireproxy + +import ( + "golang.zx2c4.com/go118/netip" + "net" +) + +func TCPAddrFromAddrPort(addr netip.AddrPort) *net.TCPAddr { + return &net.TCPAddr{ + IP: addr.Addr().AsSlice(), + Zone: addr.Addr().Zone(), + Port: int(addr.Port()), + } +} + +func UDPAddrFromAddrPort(addr netip.AddrPort) *net.UDPAddr { + return &net.UDPAddr{ + IP: addr.Addr().AsSlice(), + Zone: addr.Addr().Zone(), + Port: int(addr.Port()), + } +} diff --git a/routine.go b/routine.go index 56beee5..7c683f8 100644 --- a/routine.go +++ b/routine.go @@ -8,9 +8,11 @@ import ( "log" "math/rand" "net" + "strconv" "github.com/armon/go-socks5" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun/netstack" ) @@ -19,16 +21,59 @@ type CredentialValidator struct { password string } +type VirtualTun struct { + tnet *netstack.Net + systemDNS bool +} + type RoutineSpawner interface { - SpawnRoutine(*netstack.Net) + SpawnRoutine(vt *VirtualTun) } -type NetstackDNSResolver struct { - tnet *netstack.Net +func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) { + if d.systemDNS { + return net.DefaultResolver.LookupHost(ctx, name) + } else { + return d.tnet.LookupContextHost(ctx, name) + } } -func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { - addrs, err := d.tnet.LookupContextHost(ctx, name) +func (d VirtualTun) ResolveAddrPort(saddr string) (*netip.AddrPort, error) { + name, sport, err := net.SplitHostPort(saddr) + if err != nil { + return nil, err + } + + addrs, err := d.LookupAddr(context.Background(), name) + if err != nil { + return nil, err + } + + size := len(addrs) + if size == 0 { + return nil, errors.New("no address found for: " + name) + } + + addr, err := netip.ParseAddr(addrs[rand.Intn(size)]) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")} + } + + addrPort := netip.AddrPortFrom(addr, uint16(port)) + return &addrPort, nil +} + +func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + var addrs []string + var err error + + addrs, err = d.LookupAddr(ctx, name) + if err != nil { return ctx, nil, err } @@ -47,8 +92,8 @@ func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context. return ctx, ip, err } -func (config *Socks5Config) SpawnRoutine(tnet *netstack.Net) { - conf := &socks5.Config{Dial: tnet.DialContext, Resolver: NetstackDNSResolver{tnet: tnet}} +func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) { + conf := &socks5.Config{Dial: vt.tnet.DialContext, Resolver: vt} if username := config.Username; username != "" { validator := CredentialValidator{username: username} validator.password = config.Password @@ -68,7 +113,7 @@ func (c CredentialValidator) Valid(username, password string) bool { return c.username == username && c.password == password } -func connForward(bufSize int, from, to net.Conn) { +func connForward(bufSize int, from io.ReadWriteCloser, to io.ReadWriteCloser) { buf := make([]byte, bufSize) _, err := io.CopyBuffer(to, from, buf) if err != nil { @@ -77,8 +122,8 @@ func connForward(bufSize int, from, to net.Conn) { } } -func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) { - sconn, err := tnet.Dial("tcp", target) +func tcpClientForward(tnet *netstack.Net, target *net.TCPAddr, conn net.Conn) { + sconn, err := tnet.DialTCP(target) if err != nil { fmt.Printf("[ERROR] TCP Client Tunnel to %s: %s\n", target, err.Error()) return @@ -88,7 +133,13 @@ func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) { go connForward(1024, conn, sconn) } -func (conf *TCPClientTunnelConfig) SpawnRoutine(tnet *netstack.Net) { +func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) { + raddr, err := vt.ResolveAddrPort(conf.Target) + if err != nil { + log.Panic(err) + } + tcpAddr := TCPAddrFromAddrPort(*raddr) + server, err := net.ListenTCP("tcp", conf.BindAddress) if err != nil { log.Panic(err) @@ -99,12 +150,12 @@ func (conf *TCPClientTunnelConfig) SpawnRoutine(tnet *netstack.Net) { if err != nil { log.Panic(err) } - go tcpClientForward(tnet, conf.Target, conn) + go tcpClientForward(vt.tnet, tcpAddr, conn) } } -func tcpServerForward(target string, conn net.Conn) { - sconn, err := net.Dial("tcp", target) +func tcpServerForward(target *net.TCPAddr, conn net.Conn) { + sconn, err := net.DialTCP("tcp", nil, target) if err != nil { fmt.Printf("[ERROR] TCP Server Tunnel to %s: %s\n", target, err.Error()) return @@ -114,9 +165,15 @@ func tcpServerForward(target string, conn net.Conn) { go connForward(1024, conn, sconn) } -func (conf *TCPServerTunnelConfig) SpawnRoutine(tnet *netstack.Net) { +func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { + raddr, err := vt.ResolveAddrPort(conf.Target) + if err != nil { + log.Panic(err) + } + tcpAddr := TCPAddrFromAddrPort(*raddr) + addr := &net.TCPAddr{Port: conf.ListenPort} - server, err := tnet.ListenTCP(addr) + server, err := vt.tnet.ListenTCP(addr) if err != nil { log.Panic(err) } @@ -126,6 +183,6 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(tnet *netstack.Net) { if err != nil { log.Panic(err) } - go tcpServerForward(conf.Target, conn) + go tcpServerForward(tcpAddr, conn) } } diff --git a/wireguard.go b/wireguard.go index a77fd00..e465da2 100644 --- a/wireguard.go +++ b/wireguard.go @@ -28,7 +28,7 @@ allowed_ip=0.0.0.0/0`, conf.SelfSecretKey, conf.PeerPublicKey, conf.PeerEndpoint return setting, nil } -func StartWireguard(conf *DeviceConfig) (*netstack.Net, error) { +func StartWireguard(conf *DeviceConfig) (*VirtualTun, error) { setting, err := createIPCRequest(conf) if err != nil { return nil, err @@ -49,5 +49,8 @@ func StartWireguard(conf *DeviceConfig) (*netstack.Net, error) { return nil, err } - return tnet, nil + return &VirtualTun{ + tnet: tnet, + systemDNS: len(setting.dns) == 0, + }, nil }