From 6a557b1f069b01dce5cc6779a5d48100abf5ce4b Mon Sep 17 00:00:00 2001 From: octeep Date: Tue, 29 Mar 2022 00:19:29 +0100 Subject: [PATCH] code refactor --- cmd/main.go | 210 +-------------------------------------------------- config.go | 23 ++++-- go.mod | 4 +- routine.go | 131 ++++++++++++++++++++++++++++++++ wireguard.go | 49 ++++++++++++ 5 files changed, 200 insertions(+), 217 deletions(-) create mode 100644 routine.go create mode 100644 wireguard.go diff --git a/cmd/main.go b/cmd/main.go index 5dce308..692aa86 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,182 +1,13 @@ package main import ( - "context" - "errors" "fmt" - "io" "log" - "math/rand" - "net" "os" - "github.com/armon/go-socks5" "github.com/octeep/wireproxy" - - "golang.zx2c4.com/go118/netip" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" ) -type DeviceSetting struct { - ipcRequest string - dns []netip.Addr - deviceAddr *netip.Addr - mtu int -} - -type NetstackDNSResolver struct { - tnet *netstack.Net -} - -type CredentialValidator struct { - username string - password string -} - -func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { - addrs, err := d.tnet.LookupContextHost(ctx, name) - if err != nil { - return ctx, nil, err - } - - size := len(addrs) - if size == 0 { - return ctx, nil, errors.New("no address found for: " + name) - } - - addr := addrs[rand.Intn(size)] - ip := net.ParseIP(addr) - if ip == nil { - return ctx, nil, errors.New("invalid address: " + addr) - } - - return ctx, ip, err -} - -func createIPCRequest(conf *wireproxy.DeviceConfig) (*DeviceSetting, error) { - request := fmt.Sprintf(`private_key=%s -public_key=%s -endpoint=%s -persistent_keepalive_interval=%d -preshared_key=%s -allowed_ip=0.0.0.0/0`, conf.SelfSecretKey, conf.PeerPublicKey, conf.PeerEndpoint, conf.KeepAlive, conf.PreSharedKey) - - setting := &DeviceSetting{ipcRequest: request, dns: conf.DNS, deviceAddr: conf.SelfEndpoint, mtu: conf.MTU} - return setting, nil -} - -func socks5Routine(config *wireproxy.Socks5Config) (func(*netstack.Net), error) { - routine := func(tnet *netstack.Net) { - conf := &socks5.Config{Dial: tnet.DialContext, Resolver: NetstackDNSResolver{tnet: tnet}} - if username := config.Username; username != "" { - validator := CredentialValidator{username: username} - validator.password = config.Password - conf.Credentials = validator - } - server, err := socks5.New(conf) - if err != nil { - log.Panic(err) - } - - if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { - log.Panic(err) - } - } - - return routine, nil -} - -func (c CredentialValidator) Valid(username, password string) bool { - return c.username == username && c.password == password -} - -func connForward(bufSize int, from, to net.Conn) { - buf := make([]byte, bufSize) - _, err := io.CopyBuffer(to, from, buf) - if err != nil { - to.Close() - return - } -} - -func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) { - sconn, err := tnet.Dial("tcp", target) - if err != nil { - fmt.Printf("[ERROR] TCP Client Tunnel to %s: %s\n", target, err.Error()) - return - } - - go connForward(1024, sconn, conn) - go connForward(1024, conn, sconn) -} - -func tcpClientRoutine(conf *wireproxy.TCPClientTunnelConfig) (func(*netstack.Net), error) { - routine := func(tnet *netstack.Net) { - server, err := net.ListenTCP("tcp", conf.BindAddress) - if err != nil { - log.Panic(err) - } - - for { - conn, err := server.Accept() - if err != nil { - log.Panic(err) - } - go tcpClientForward(tnet, conf.Target, conn) - } - } - - return routine, nil -} - -func tcpServerForward(target string, conn net.Conn) { - sconn, err := net.Dial("tcp", target) - if err != nil { - fmt.Printf("[ERROR] TCP Server Tunnel to %s: %s\n", target, err.Error()) - return - } - - go connForward(1024, sconn, conn) - go connForward(1024, conn, sconn) -} - -func tcpServerRoutine(conf *wireproxy.TCPServerTunnelConfig) (func(*netstack.Net), error) { - routine := func(tnet *netstack.Net) { - addr := &net.TCPAddr{Port: conf.ListenPort} - server, err := tnet.ListenTCP(addr) - if err != nil { - log.Panic(err) - } - - for { - conn, err := server.Accept() - if err != nil { - log.Panic(err) - } - go tcpServerForward(conf.Target, conn) - } - } - - return routine, nil -} - -func startWireguard(setting *DeviceSetting) (*netstack.Net, error) { - tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{*(setting.deviceAddr)}, setting.dns, setting.mtu) - if err != nil { - return nil, err - } - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) - dev.IpcSet(setting.ipcRequest) - err = dev.Up() - if err != nil { - return nil, err - } - - return tnet, nil -} - func main() { if len(os.Args) != 2 { fmt.Println("Usage: wireproxy [config file path]") @@ -188,48 +19,13 @@ func main() { log.Panic(err) } - setting, err := createIPCRequest(conf.Device) + tnet, err := wireproxy.StartWireguard(conf.Device) if err != nil { log.Panic(err) } - routines := [](func(*netstack.Net)){} - var routine func(*netstack.Net) - - for _, config := range conf.TCPClientTunnels { - routine, err = tcpClientRoutine(&config) - if err != nil { - log.Panic(err) - } - - routines = append(routines, routine) - } - - for _, config := range conf.TCPServerTunnels { - routine, err = tcpServerRoutine(&config) - if err != nil { - log.Panic(err) - } - - routines = append(routines, routine) - } - - for _, config := range conf.Socks5Proxies { - routine, err = socks5Routine(&config) - if err != nil { - log.Panic(err) - } - - routines = append(routines, routine) - } - - tnet, err := startWireguard(setting) - if err != nil { - log.Panic(err) - } - - for _, netRoutine := range routines { - go netRoutine(tnet) + for _, spawner := range conf.Routines { + go spawner.SpawnRoutine(tnet) } select {} // sleep eternally diff --git a/config.go b/config.go index a7586b8..1420042 100644 --- a/config.go +++ b/config.go @@ -40,10 +40,8 @@ type Socks5Config struct { } type Configuration struct { - Device *DeviceConfig - TCPClientTunnels []TCPClientTunnelConfig - TCPServerTunnels []TCPServerTunnelConfig - Socks5Proxies []Socks5Config + Device *DeviceConfig + Routines []RoutineSpawner } func parseString(section *ini.Section, keyName string) (string, error) { @@ -329,25 +327,34 @@ func ParseConfig(path string) (*Configuration, error) { return nil, err } + routinesSpawners := []RoutineSpawner{} + tcpClientTunnels, err := ParseTCPClientTunnelConfig(cfg) if err != nil { return nil, err } + for _, i := range tcpClientTunnels { + routinesSpawners = append(routinesSpawners, &i) + } tcpServerTunnels, err := ParseTCPServerTunnelConfig(cfg) if err != nil { return nil, err } + for _, i := range tcpServerTunnels { + routinesSpawners = append(routinesSpawners, &i) + } socks5Proxies, err := ParseSocks5Config(cfg) if err != nil { return nil, err } + for _, i := range socks5Proxies { + routinesSpawners = append(routinesSpawners, &i) + } return &Configuration{ - Device: device, - TCPClientTunnels: tcpClientTunnels, - TCPServerTunnels: tcpServerTunnels, - Socks5Proxies: socks5Proxies, + Device: device, + Routines: routinesSpawners, }, nil } diff --git a/go.mod b/go.mod index 2c8519a..5846f23 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/octeep/wireproxy go 1.17 require ( - github.com/go-ini/ini v1.66.4 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 + github.com/go-ini/ini v1.66.4 golang.org/x/net v0.0.0-20220225172249-27dd8689420f golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d golang.zx2c4.com/wireguard v0.0.0-20220202223031-3b95c81cc178 + golang.zx2c4.com/wireguard/tun/netstack v0.0.0-20220310012736-ae6bc4dd64e1 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 ) @@ -17,5 +18,4 @@ require ( golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect - golang.zx2c4.com/wireguard/tun/netstack v0.0.0-20220310012736-ae6bc4dd64e1 // indirect ) diff --git a/routine.go b/routine.go new file mode 100644 index 0000000..56beee5 --- /dev/null +++ b/routine.go @@ -0,0 +1,131 @@ +package wireproxy + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + + "github.com/armon/go-socks5" + + "golang.zx2c4.com/wireguard/tun/netstack" +) + +type CredentialValidator struct { + username string + password string +} + +type RoutineSpawner interface { + SpawnRoutine(*netstack.Net) +} + +type NetstackDNSResolver struct { + tnet *netstack.Net +} + +func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + addrs, err := d.tnet.LookupContextHost(ctx, name) + if err != nil { + return ctx, nil, err + } + + size := len(addrs) + if size == 0 { + return ctx, nil, errors.New("no address found for: " + name) + } + + addr := addrs[rand.Intn(size)] + ip := net.ParseIP(addr) + if ip == nil { + return ctx, nil, errors.New("invalid address: " + addr) + } + + return ctx, ip, err +} + +func (config *Socks5Config) SpawnRoutine(tnet *netstack.Net) { + conf := &socks5.Config{Dial: tnet.DialContext, Resolver: NetstackDNSResolver{tnet: tnet}} + if username := config.Username; username != "" { + validator := CredentialValidator{username: username} + validator.password = config.Password + conf.Credentials = validator + } + server, err := socks5.New(conf) + if err != nil { + log.Panic(err) + } + + if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { + log.Panic(err) + } +} + +func (c CredentialValidator) Valid(username, password string) bool { + return c.username == username && c.password == password +} + +func connForward(bufSize int, from, to net.Conn) { + buf := make([]byte, bufSize) + _, err := io.CopyBuffer(to, from, buf) + if err != nil { + to.Close() + return + } +} + +func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) { + sconn, err := tnet.Dial("tcp", target) + if err != nil { + fmt.Printf("[ERROR] TCP Client Tunnel to %s: %s\n", target, err.Error()) + return + } + + go connForward(1024, sconn, conn) + go connForward(1024, conn, sconn) +} + +func (conf *TCPClientTunnelConfig) SpawnRoutine(tnet *netstack.Net) { + server, err := net.ListenTCP("tcp", conf.BindAddress) + if err != nil { + log.Panic(err) + } + + for { + conn, err := server.Accept() + if err != nil { + log.Panic(err) + } + go tcpClientForward(tnet, conf.Target, conn) + } +} + +func tcpServerForward(target string, conn net.Conn) { + sconn, err := net.Dial("tcp", target) + if err != nil { + fmt.Printf("[ERROR] TCP Server Tunnel to %s: %s\n", target, err.Error()) + return + } + + go connForward(1024, sconn, conn) + go connForward(1024, conn, sconn) +} + +func (conf *TCPServerTunnelConfig) SpawnRoutine(tnet *netstack.Net) { + addr := &net.TCPAddr{Port: conf.ListenPort} + server, err := tnet.ListenTCP(addr) + if err != nil { + log.Panic(err) + } + + for { + conn, err := server.Accept() + if err != nil { + log.Panic(err) + } + go tcpServerForward(conf.Target, conn) + } +} diff --git a/wireguard.go b/wireguard.go new file mode 100644 index 0000000..938f1c0 --- /dev/null +++ b/wireguard.go @@ -0,0 +1,49 @@ +package wireproxy + +import ( + "fmt" + + "golang.zx2c4.com/go118/netip" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +type DeviceSetting struct { + ipcRequest string + dns []netip.Addr + deviceAddr *netip.Addr + mtu int +} + +func createIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { + request := fmt.Sprintf(`private_key=%s +public_key=%s +endpoint=%s +persistent_keepalive_interval=%d +preshared_key=%s +allowed_ip=0.0.0.0/0`, conf.SelfSecretKey, conf.PeerPublicKey, conf.PeerEndpoint, conf.KeepAlive, conf.PreSharedKey) + + setting := &DeviceSetting{ipcRequest: request, dns: conf.DNS, deviceAddr: conf.SelfEndpoint, mtu: conf.MTU} + return setting, nil +} + +func StartWireguard(conf *DeviceConfig) (*netstack.Net, error) { + setting, err := createIPCRequest(conf) + if err != nil { + return nil, err + } + + tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{*(setting.deviceAddr)}, setting.dns, setting.mtu) + if err != nil { + return nil, err + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev.IpcSet(setting.ipcRequest) + err = dev.Up() + if err != nil { + return nil, err + } + + return tnet, nil +}