This commit is contained in:
octeep 2022-03-28 17:25:51 +01:00 committed by octeep
parent e663f3d412
commit 9f0fe5d20d

619
main.go
View file

@ -1,440 +1,439 @@
package main package main
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"github.com/armon/go-socks5" "github.com/armon/go-socks5"
"golang.zx2c4.com/go118/netip" "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
) )
type ConfigSection struct { type ConfigSection struct {
name string name string
entries map[string]string entries map[string]string
} }
type DeviceSetting struct { type DeviceSetting struct {
ipcRequest string ipcRequest string
dns []netip.Addr dns []netip.Addr
deviceAddr *netip.Addr deviceAddr *netip.Addr
} }
type NetstackDNSResolver struct { type NetstackDNSResolver struct {
tnet *netstack.Net tnet *netstack.Net
} }
type Configuration []ConfigSection type Configuration []ConfigSection
type CredentialValidator struct { type CredentialValidator struct {
username string username string
password string password string
} }
func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { func (d NetstackDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addrs, err := d.tnet.LookupContextHost(ctx, name) addrs, err := d.tnet.LookupContextHost(ctx, name)
if err != nil { if err != nil {
return ctx, nil, err return ctx, nil, err
} }
size := len(addrs) size := len(addrs)
if size == 0 { if size == 0 {
return ctx, nil, errors.New("no address found for: " + name) return ctx, nil, errors.New("no address found for: " + name)
} }
addr := addrs[rand.Intn(size)] addr := addrs[rand.Intn(size)]
ip := net.ParseIP(addr) ip := net.ParseIP(addr)
if ip == nil { if ip == nil {
return ctx, nil, errors.New("invalid address: " + addr) return ctx, nil, errors.New("invalid address: " + addr)
} }
return ctx, ip, err return ctx, ip, err
} }
func configRoot(config Configuration) map[string]string { func configRoot(config Configuration) map[string]string {
for _, section := range config { for _, section := range config {
if section.name == "ROOT" { if section.name == "ROOT" {
return section.entries return section.entries
} }
} }
return nil return nil
} }
func readConfig(path string) (Configuration, error) { func readConfig(path string) (Configuration, error) {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer file.Close()
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
section := ConfigSection{name: "ROOT", entries: map[string]string{}} section := ConfigSection{name: "ROOT", entries: map[string]string{}}
sections := []ConfigSection{} sections := []ConfigSection{}
lineNo := 0 lineNo := 0
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
lineNo += 1 lineNo += 1
if hashIndex := strings.Index(line, "#"); hashIndex >= 0 { if hashIndex := strings.Index(line, "#"); hashIndex >= 0 {
line = line[:hashIndex] line = line[:hashIndex]
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" { if line == "" {
continue continue
} }
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
sections = append(sections, section) sections = append(sections, section)
section = ConfigSection{name: strings.ToLower(line), entries: map[string]string{}} section = ConfigSection{name: strings.ToLower(line), entries: map[string]string{}}
continue continue
} }
entry := strings.SplitN(line, "=", 2) entry := strings.SplitN(line, "=", 2)
if len(entry) != 2 { if len(entry) != 2 {
return nil, errors.New(fmt.Sprintf("invalid syntax at line %d: %s", lineNo, line)) return nil, errors.New(fmt.Sprintf("invalid syntax at line %d: %s", lineNo, line))
} }
key := strings.TrimSpace(entry[0]) key := strings.TrimSpace(entry[0])
key = strings.ToLower(key) key = strings.ToLower(key)
value := strings.TrimSpace(entry[1]) value := strings.TrimSpace(entry[1])
if _, dup := section.entries[key]; dup { if _, dup := section.entries[key]; dup {
return nil, errors.New(fmt.Sprintf("duplicate key line %d: %s", lineNo, line)) return nil, errors.New(fmt.Sprintf("duplicate key line %d: %s", lineNo, line))
} }
section.entries[key] = value section.entries[key] = value
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return nil, err return nil, err
} }
sections = append(sections, section) sections = append(sections, section)
return sections, nil return sections, nil
} }
func parseBase64Key(key string) (string, error) { func parseBase64Key(key string) (string, error) {
decoded, err := base64.StdEncoding.DecodeString(key) decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil { if err != nil {
return "", errors.New("invalid base64 string") return "", errors.New("invalid base64 string")
} }
if len(decoded) != 32 { if len(decoded) != 32 {
return "", errors.New("key should be 32 bytes") return "", errors.New("key should be 32 bytes")
} }
return hex.EncodeToString(decoded), nil return hex.EncodeToString(decoded), nil
} }
func resolveIP(ip string) (*net.IPAddr, error) { func resolveIP(ip string) (*net.IPAddr, error) {
return net.ResolveIPAddr("ip", ip) return net.ResolveIPAddr("ip", ip)
} }
func resolveIPPAndPort(addr string) (string, error) { func resolveIPPAndPort(addr string) (string, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return "", err return "", err
} }
ip, err := resolveIP(host) ip, err := resolveIP(host)
if err != nil { if err != nil {
return "", err return "", err
} }
return net.JoinHostPort(ip.String(), port), nil return net.JoinHostPort(ip.String(), port), nil
} }
func parseIPs(s string) ([]netip.Addr, error) { func parseIPs(s string) ([]netip.Addr, error) {
ips := []netip.Addr{} ips := []netip.Addr{}
for _, str := range strings.Split(s, ",") { for _, str := range strings.Split(s, ",") {
str = strings.TrimSpace(str) str = strings.TrimSpace(str)
ip, err := netip.ParseAddr(str) ip, err := netip.ParseAddr(str)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ips = append(ips, ip) ips = append(ips, ip)
} }
return ips, nil return ips, nil
} }
func createIPCRequest(conf Configuration) (*DeviceSetting, error) { func createIPCRequest(conf Configuration) (*DeviceSetting, error) {
root := configRoot(conf) root := configRoot(conf)
peerPK, err := parseBase64Key(root["peerpublickey"]) peerPK, err := parseBase64Key(root["peerpublickey"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
selfSK, err := parseBase64Key(root["selfsecretkey"]) selfSK, err := parseBase64Key(root["selfsecretkey"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
peerEndpoint, err := resolveIPPAndPort(root["peerendpoint"]) peerEndpoint, err := resolveIPPAndPort(root["peerendpoint"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
selfEndpoint, err := netip.ParseAddr(root["selfendpoint"]) selfEndpoint, err := netip.ParseAddr(root["selfendpoint"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
dns, err := parseIPs(root["dns"]) dns, err := parseIPs(root["dns"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
keepAlive := int64(0) keepAlive := int64(0)
if keepAliveOpt, ok := root["keepalive"]; ok { if keepAliveOpt, ok := root["keepalive"]; ok {
keepAlive, err = strconv.ParseInt(keepAliveOpt, 10, 0) keepAlive, err = strconv.ParseInt(keepAliveOpt, 10, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if keepAlive < 0 { if keepAlive < 0 {
keepAlive = 0 keepAlive = 0
} }
} }
preSharedKey := "0000000000000000000000000000000000000000000000000000000000000000" preSharedKey := "0000000000000000000000000000000000000000000000000000000000000000"
if pskOpt, ok := root["presharedkey"]; ok { if pskOpt, ok := root["presharedkey"]; ok {
preSharedKey, err = parseBase64Key(pskOpt) preSharedKey, err = parseBase64Key(pskOpt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
request := fmt.Sprintf(`private_key=%s request := fmt.Sprintf(`private_key=%s
public_key=%s public_key=%s
endpoint=%s endpoint=%s
persistent_keepalive_interval=%d persistent_keepalive_interval=%d
preshared_key=%s preshared_key=%s
allowed_ip=0.0.0.0/0`, selfSK, peerPK, peerEndpoint, keepAlive, preSharedKey) allowed_ip=0.0.0.0/0`, selfSK, peerPK, peerEndpoint, keepAlive, preSharedKey)
setting := &DeviceSetting{ ipcRequest: request, dns: dns, deviceAddr: &selfEndpoint } setting := &DeviceSetting{ipcRequest: request, dns: dns, deviceAddr: &selfEndpoint}
return setting, nil return setting, nil
} }
func socks5Routine(config map[string]string) (func(*netstack.Net), error) { func socks5Routine(config map[string]string) (func(*netstack.Net), error) {
bindAddr, ok := config["bindaddress"] bindAddr, ok := config["bindaddress"]
if !ok { if !ok {
return nil, errors.New("missing bind address") return nil, errors.New("missing bind address")
} }
routine := func(tnet *netstack.Net) { routine := func(tnet *netstack.Net) {
conf := &socks5.Config{Dial: tnet.DialContext, Resolver: NetstackDNSResolver{tnet: tnet}} conf := &socks5.Config{Dial: tnet.DialContext, Resolver: NetstackDNSResolver{tnet: tnet}}
if username, ok := config["username"]; ok { if username, ok := config["username"]; ok {
validator := CredentialValidator{username: username} validator := CredentialValidator{username: username}
password, ok := config["password"] password, ok := config["password"]
if ok { if ok {
validator.password = password validator.password = password
} }
conf.Credentials = validator conf.Credentials = validator
} }
server, err := socks5.New(conf) server, err := socks5.New(conf)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
if err := server.ListenAndServe("tcp", bindAddr); err != nil { if err := server.ListenAndServe("tcp", bindAddr); err != nil {
log.Panic(err) log.Panic(err)
} }
} }
return routine, nil return routine, nil
} }
func (c CredentialValidator) Valid(username, password string) bool { func (c CredentialValidator) Valid(username, password string) bool {
return c.username == username && c.password == password return c.username == username && c.password == password
} }
func connForward(bufSize int, from, to net.Conn) { func connForward(bufSize int, from, to net.Conn) {
buf := make([]byte, bufSize) buf := make([]byte, bufSize)
for { for {
size, err := from.Read(buf) size, err := from.Read(buf)
if err != nil { if err != nil {
to.Close() to.Close()
return return
} }
_, err = to.Write(buf[:size]) _, err = to.Write(buf[:size])
if err != nil { if err != nil {
to.Close() to.Close()
return return
} }
} }
} }
func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) { func tcpClientForward(tnet *netstack.Net, target string, conn net.Conn) {
sconn, err := tnet.Dial("tcp", target) sconn, err := tnet.Dial("tcp", target)
if err != nil { if err != nil {
fmt.Printf("[ERROR] TCP Client Tunnel to %s: %s\n", target, err.Error()) fmt.Printf("[ERROR] TCP Client Tunnel to %s: %s\n", target, err.Error())
return return
} }
go connForward(1024, sconn, conn) go connForward(1024, sconn, conn)
go connForward(1024, conn, sconn) go connForward(1024, conn, sconn)
} }
func tcpClientRoutine(config map[string]string) (func(*netstack.Net), error) { func tcpClientRoutine(config map[string]string) (func(*netstack.Net), error) {
bindAddr, ok := config["bindaddress"] bindAddr, ok := config["bindaddress"]
if !ok { if !ok {
return nil, errors.New("missing bind address") return nil, errors.New("missing bind address")
} }
bindTCPAddr, err := net.ResolveTCPAddr("tcp", bindAddr) bindTCPAddr, err := net.ResolveTCPAddr("tcp", bindAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
target, ok := config["target"] target, ok := config["target"]
if !ok { if !ok {
return nil, errors.New("missing target") return nil, errors.New("missing target")
} }
routine := func(tnet *netstack.Net) { routine := func(tnet *netstack.Net) {
server, err := net.ListenTCP("tcp", bindTCPAddr) server, err := net.ListenTCP("tcp", bindTCPAddr)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
for { for {
conn, err := server.Accept() conn, err := server.Accept()
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
go tcpClientForward(tnet, target, conn) go tcpClientForward(tnet, target, conn)
} }
} }
return routine, nil return routine, nil
} }
func tcpServerForward(target string, conn net.Conn) { func tcpServerForward(target string, conn net.Conn) {
sconn, err := net.Dial("tcp", target) sconn, err := net.Dial("tcp", target)
if err != nil { if err != nil {
fmt.Printf("[ERROR] TCP Server Tunnel to %s: %s\n", target, err.Error()) fmt.Printf("[ERROR] TCP Server Tunnel to %s: %s\n", target, err.Error())
return return
} }
go connForward(1024, sconn, conn) go connForward(1024, sconn, conn)
go connForward(1024, conn, sconn) go connForward(1024, conn, sconn)
} }
func tcpServerRoutine(config map[string]string) (func(*netstack.Net), error) { func tcpServerRoutine(config map[string]string) (func(*netstack.Net), error) {
listenPort, err := strconv.ParseInt(config["listenport"], 10, 0) listenPort, err := strconv.ParseInt(config["listenport"], 10, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if listenPort < 1 || listenPort > 65535 { if listenPort < 1 || listenPort > 65535 {
return nil, errors.New("listen port out of bound") return nil, errors.New("listen port out of bound")
} }
addr := &net.TCPAddr{Port : int(listenPort)} addr := &net.TCPAddr{Port: int(listenPort)}
target, ok := config["target"] target, ok := config["target"]
if !ok { if !ok {
return nil, errors.New("missing target") return nil, errors.New("missing target")
} }
routine := func(tnet *netstack.Net) { routine := func(tnet *netstack.Net) {
server, err := tnet.ListenTCP(addr) server, err := tnet.ListenTCP(addr)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
for { for {
conn, err := server.Accept() conn, err := server.Accept()
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
go tcpServerForward(target, conn) go tcpServerForward(target, conn)
} }
} }
return routine, nil return routine, nil
} }
func startWireguard(setting *DeviceSetting) (*netstack.Net, error) { func startWireguard(setting *DeviceSetting) (*netstack.Net, error) {
tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{*(setting.deviceAddr)}, setting.dns, 1420) tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{*(setting.deviceAddr)}, setting.dns, 1420)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(setting.ipcRequest) dev.IpcSet(setting.ipcRequest)
err = dev.Up() err = dev.Up()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tnet, nil return tnet, nil
} }
func main() { func main() {
if len(os.Args) != 2 { if len(os.Args) != 2 {
fmt.Println("Usage: wireproxy [config file path]") fmt.Println("Usage: wireproxy [config file path]")
return return
} }
conf, err := readConfig(os.Args[1]) conf, err := readConfig(os.Args[1])
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
setting, err := createIPCRequest(conf) setting, err := createIPCRequest(conf)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
routines := [](func(*netstack.Net)){} routines := [](func(*netstack.Net)){}
var routine func(*netstack.Net) var routine func(*netstack.Net)
for _, section := range conf { for _, section := range conf {
switch section.name { switch section.name {
case "[socks5]": case "[socks5]":
routine, err = socks5Routine(section.entries) routine, err = socks5Routine(section.entries)
case "[tcpclienttunnel]": case "[tcpclienttunnel]":
routine, err = tcpClientRoutine(section.entries) routine, err = tcpClientRoutine(section.entries)
case "[tcpservertunnel]": case "[tcpservertunnel]":
routine, err = tcpServerRoutine(section.entries) routine, err = tcpServerRoutine(section.entries)
case "ROOT": case "ROOT":
continue continue
default: default:
log.Panic(errors.New(fmt.Sprintf("unsupported proxy: %s", section.name))) log.Panic(errors.New(fmt.Sprintf("unsupported proxy: %s", section.name)))
} }
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
routines = append(routines, routine) routines = append(routines, routine)
} }
tnet, err := startWireguard(setting) tnet, err := startWireguard(setting)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
for _, netRoutine := range routines { for _, netRoutine := range routines {
go netRoutine(tnet) go netRoutine(tnet)
} }
select{} // sleep eternally select {} // sleep eternally
} }