wireproxy/routine.go
2022-05-20 12:25:48 +01:00

240 lines
6 KiB
Go

package wireproxy
import (
"context"
"crypto/subtle"
"errors"
"io"
"log"
"math/rand"
"net"
"os"
"strconv"
"github.com/armon/go-socks5"
"golang.zx2c4.com/wireguard/tun/netstack"
"net/netip"
)
var (
	// errorLogger writes error messages to stderr, prefixed with "ERROR: "
	// and the standard date/time flags.
	errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags)
)

type (
	// CredentialValidator holds the username/password pair a SOCKS5 client
	// must present (checked by its Valid method).
	CredentialValidator struct {
		username, password string
	}

	// VirtualTun bundles the WireGuard netstack network with the DNS policy
	// used for name lookups: when systemDNS is true the host's resolver is
	// used instead of the netstack.
	VirtualTun struct {
		tnet      *netstack.Net
		systemDNS bool
	}

	// RoutineSpawner is implemented by each configured service (e.g. the
	// SOCKS5 server or a static TCP route); its SpawnRoutine is invoked once
	// the configuration has been parsed.
	RoutineSpawner interface {
		SpawnRoutine(vt *VirtualTun)
	}

	// addressPort is an unresolved host plus a numeric port, as parsed from
	// a "host:port" configuration value.
	addressPort struct {
		address string
		port    uint16
	}
)
// LookupAddr returns the host addresses for name.
// When systemDNS is false the query goes through the WireGuard netstack
// (LookupContextHost); otherwise the host's default resolver
// (net.DefaultResolver) is used and DNS traffic is not routed via the tunnel.
func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) {
	if !d.systemDNS {
		return d.tnet.LookupContextHost(ctx, name)
	}
	return net.DefaultResolver.LookupHost(ctx, name)
}
// ResolveAddrWithContext looks up name (honouring VirtualTun's systemDNS
// setting via LookupAddr) and returns one of the resulting addresses as a
// netip.Addr.
//
// The lookup results are shuffled before parsing, so repeated calls do not
// always pick the same address. The first result that parses as an IP is
// returned. An error is returned when:
//   - the lookup fails;
//   - it yields no addresses;
//   - none of the results parse (the last parse error is returned).
func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) {
	addrs, err := d.LookupAddr(ctx, name)
	if err != nil {
		return nil, err
	}
	if len(addrs) == 0 {
		return nil, errors.New("no address found for: " + name)
	}

	// Randomise the order so successive resolutions spread over all records.
	rand.Shuffle(len(addrs), func(i, j int) {
		addrs[i], addrs[j] = addrs[j], addrs[i]
	})

	var parseErr error
	for _, s := range addrs {
		addr, err := netip.ParseAddr(s)
		if err == nil {
			return &addr, nil
		}
		parseErr = err
	}
	// Every candidate failed to parse; report the last failure.
	return nil, parseErr
}
// Resolve implements the name-resolution hook used by the SOCKS5 server
// (it is passed as socks5.Config.Resolver). It resolves name through
// ResolveAddrWithContext, so the systemDNS setting applies, and returns
// ctx unchanged together with the chosen IP. On failure the returned
// context and IP are both nil.
func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
	resolved, err := d.ResolveAddrWithContext(ctx, name)
	if err != nil {
		return nil, nil, err
	}
	ip := net.IP(resolved.AsSlice())
	return ctx, ip, nil
}
// parseAddressPort splits endpoint ("host:port" or "[v6addr]:port") into its
// host part and numeric port. The host is not resolved here.
//
// The port must be a plain decimal number in [0, 65535]. Anything else is
// rejected with a *net.OpError, including a sign, non-digits, or an
// out-of-range value. Malformed host:port syntax returns the error from
// net.SplitHostPort.
func parseAddressPort(endpoint string) (*addressPort, error) {
	host, portStr, err := net.SplitHostPort(endpoint)
	if err != nil {
		return nil, err
	}
	// bitSize 16 rejects values above 65535, and ParseUint rejects any sign,
	// so the whole validity check is a single call.
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric and in the range 0-65535")}
	}
	return &addressPort{address: host, port: uint16(port)}, nil
}
// resolveToAddrPort resolves endpoint.address through ResolveAddrWithContext
// (with a background context, i.e. no deadline or cancellation) and pairs
// the resulting IP with endpoint.port.
func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) {
	var result netip.AddrPort
	ip, err := d.ResolveAddrWithContext(context.Background(), endpoint.address)
	if err != nil {
		return nil, err
	}
	result = netip.AddrPortFrom(*ip, endpoint.port)
	return &result, nil
}
// SpawnRoutine runs a SOCKS5 server on config.BindAddress.
// Outbound connections are dialed through the WireGuard netstack
// (vt.tnet.DialContext), and hostnames are resolved by vt. When
// config.Username is non-empty, clients must authenticate with the
// configured username/password pair. The process is terminated via
// log.Fatal if the server cannot be created or ListenAndServe returns
// an error.
func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) {
	socksConf := &socks5.Config{
		Dial:     vt.tnet.DialContext,
		Resolver: vt,
	}
	if config.Username != "" {
		socksConf.Credentials = CredentialValidator{
			username: config.Username,
			password: config.Password,
		}
	}

	server, err := socks5.New(socksConf)
	if err != nil {
		log.Fatal(err)
	}
	if err = server.ListenAndServe("tcp", config.BindAddress); err != nil {
		log.Fatal(err)
	}
}
// Valid reports whether username and password both match the stored
// credentials.
//
// Each field is compared with subtle.ConstantTimeCompare, and both
// comparisons always run, so a wrong username does not short-circuit the
// password check. ConstantTimeCompare itself still returns immediately when
// the lengths differ, so credential lengths are not hidden.
func (c CredentialValidator) Valid(username, password string) bool {
	userOK := subtle.ConstantTimeCompare([]byte(username), []byte(c.username))
	passOK := subtle.ConstantTimeCompare([]byte(password), []byte(c.password))
	return userOK&passOK == 1
}
// connForward copies everything read from `from` into `to`, then closes both
// streams. Closing both lets the forwarder running in the opposite direction
// (callers start one per direction) stop as well.
//
// bufSize is the scratch-buffer size handed to io.CopyBuffer. A value <= 0
// lets io.CopyBuffer allocate its own default buffer instead of panicking on
// an empty one. The buffer is unused when either side implements
// io.WriterTo / io.ReaderFrom.
//
// The opposite-direction forwarder closes both streams as soon as it
// finishes, so a copy that fails with net.ErrClosed is the ordinary way a
// proxied connection ends and is not logged. Any other copy error is reported
// through errorLogger.
func connForward(bufSize int, from io.ReadWriteCloser, to io.ReadWriteCloser) {
	var buf []byte
	if bufSize > 0 {
		buf = make([]byte, bufSize)
	}
	if _, err := io.CopyBuffer(to, from, buf); err != nil && !errors.Is(err, net.ErrClosed) {
		errorLogger.Printf("Cannot forward traffic: %s\n", err.Error())
	}
	// Close errors are deliberately ignored: the pair is being torn down and
	// the peer forwarder may already have closed one or both streams.
	_ = from.Close()
	_ = to.Close()
}
// tcpClientForward resolves raddr, dials it through the WireGuard netstack,
// and relays traffic in both directions between conn and the new tunnel
// connection. It returns as soon as the two relay goroutines are started.
//
// If resolution or dialing fails, the error is logged and conn is closed.
// Closing conn means the accepted local client is not leaked or left waiting
// on a connection that will never be served.
func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
	target, err := vt.resolveToAddrPort(raddr)
	if err != nil {
		// target is nil here, so describe the destination from raddr instead.
		dest := net.JoinHostPort(raddr.address, strconv.Itoa(int(raddr.port)))
		errorLogger.Printf("TCP Client Tunnel to %s: %s\n", dest, err.Error())
		_ = conn.Close()
		return
	}
	tcpAddr := TCPAddrFromAddrPort(*target)
	sconn, err := vt.tnet.DialTCP(tcpAddr)
	if err != nil {
		errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error())
		_ = conn.Close()
		return
	}
	// One forwarder per direction. Whichever finishes first closes both
	// connections, which ends the other one.
	go connForward(1024, sconn, conn)
	go connForward(1024, conn, sconn)
}
// SpawnRoutine listens for TCP connections on conf.BindAddress on the host
// network. Each accepted connection is handed to tcpClientForward, which
// tunnels it to conf.Target through WireGuard. The accept loop never exits
// on its own; an invalid target, a listen failure, or any Accept error
// terminates the process via log.Fatal.
func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) {
	target, err := parseAddressPort(conf.Target)
	if err != nil {
		log.Fatal(err)
	}

	listener, err := net.ListenTCP("tcp", conf.BindAddress)
	if err != nil {
		log.Fatal(err)
	}

	for {
		client, acceptErr := listener.Accept()
		if acceptErr != nil {
			log.Fatal(acceptErr)
		}
		go tcpClientForward(vt, target, client)
	}
}
// tcpServerForward resolves raddr, dials it on the host network (not through
// the tunnel), and relays traffic in both directions between conn and the new
// local connection. It returns as soon as the two relay goroutines are
// started.
//
// If resolution or dialing fails, the error is logged and conn is closed, so
// the connection accepted on the WireGuard side is not leaked.
func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
	target, err := vt.resolveToAddrPort(raddr)
	if err != nil {
		// target is nil here, so describe the destination from raddr instead.
		dest := net.JoinHostPort(raddr.address, strconv.Itoa(int(raddr.port)))
		errorLogger.Printf("TCP Server Tunnel to %s: %s\n", dest, err.Error())
		_ = conn.Close()
		return
	}
	tcpAddr := TCPAddrFromAddrPort(*target)
	sconn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
		_ = conn.Close()
		return
	}
	// One forwarder per direction. Whichever finishes first closes both
	// connections, which ends the other one.
	go connForward(1024, sconn, conn)
	go connForward(1024, conn, sconn)
}
// SpawnRoutine listens on TCP port conf.ListenPort inside the WireGuard
// netstack, using an unspecified IP. Each accepted connection is handed to
// tcpServerForward, which relays it to conf.Target on the host network. The
// accept loop never exits on its own; an invalid target, a listen failure,
// or any Accept error terminates the process via log.Fatal.
func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
	target, err := parseAddressPort(conf.Target)
	if err != nil {
		log.Fatal(err)
	}

	listener, err := vt.tnet.ListenTCP(&net.TCPAddr{Port: conf.ListenPort})
	if err != nil {
		log.Fatal(err)
	}

	for {
		peer, acceptErr := listener.Accept()
		if acceptErr != nil {
			log.Fatal(acceptErr)
		}
		go tcpServerForward(vt, target, peer)
	}
}