diff --git a/config.go b/config.go index 1430a70..91bb4d2 100644 --- a/config.go +++ b/config.go @@ -12,15 +12,20 @@ import ( "net/netip" ) +type PeerConfig struct { + PublicKey string + PreSharedKey string + Endpoint string + KeepAlive int + AllowedIPs []netip.Addr +} + // DeviceConfig contains the information to initiate a wireguard connection type DeviceConfig struct { SelfSecretKey string SelfEndpoint []netip.Addr - PeerPublicKey string - PeerEndpoint string + Peers []PeerConfig DNS []netip.Addr - KeepAlive int - PreSharedKey string MTU int } @@ -144,6 +149,25 @@ func parseCIDRNetIP(section *ini.Section, keyName string) ([]netip.Addr, error) return ips, nil } +func parseAllowedIPs(section *ini.Section) ([]netip.Addr, error) { + key := section.Key("AllowedIPs") + if key == nil { + return []netip.Addr{}, nil + } + + var ips []netip.Addr + for _, str := range key.StringsWithShadows(",") { + prefix, err := netip.ParsePrefix(str) + if err != nil { + return nil, err + } + + addr := prefix.Addr() + ips = append(ips, addr) + } + return ips, nil +} + func resolveIP(ip string) (*net.IPAddr, error) { return net.ResolveIPAddr("ip", ip) } @@ -199,46 +223,55 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error { return nil } -// ParsePeer parses the [Peer] section and extract the information into `device` -func ParsePeer(cfg *ini.File, device *DeviceConfig) error { +// ParsePeer parses the [Peer] section and extract the information into `peers` +func ParsePeers(cfg *ini.File, peers *[]PeerConfig) error { sections, err := cfg.SectionsByName("Peer") - if len(sections) != 1 || err != nil { - return errors.New("one and only one [Peer] is expected") + if len(sections) < 1 || err != nil { + return errors.New("at least one [Peer] is expected") } - section := sections[0] - decoded, err := parseBase64KeyToHex(section, "PublicKey") - if err != nil { - return err - } - device.PeerPublicKey = decoded + for _, section := range sections { + peer := PeerConfig{} - if sectionKey, err := section.GetKey("PreSharedKey"); err == nil { - value, err := encodeBase64ToHex(sectionKey.String()) + decoded, err := parseBase64KeyToHex(section, "PublicKey") if err != nil { return err } - device.PreSharedKey = value - } + peer.PublicKey = decoded - decoded, err = parseString(section, "Endpoint") - if err != nil { - return err - } - decoded, err = resolveIPPAndPort(decoded) - if err != nil { - return err - } - device.PeerEndpoint = decoded + if sectionKey, err := section.GetKey("PreSharedKey"); err == nil { + value, err := encodeBase64ToHex(sectionKey.String()) + if err != nil { + return err + } + peer.PreSharedKey = value + } - if sectionKey, err := section.GetKey("PersistentKeepalive"); err == nil { - value, err := sectionKey.Int() + decoded, err = parseString(section, "Endpoint") if err != nil { return err } - device.KeepAlive = value - } + decoded, err = resolveIPPAndPort(decoded) + if err != nil { + return err + } + peer.Endpoint = decoded + if sectionKey, err := section.GetKey("PersistentKeepalive"); err == nil { + value, err := sectionKey.Int() + if err != nil { + return err + } + peer.KeepAlive = value + } + + peer.AllowedIPs, err = parseAllowedIPs(section) + if err != nil { + return err + } + + *peers = append(*peers, peer) + } return nil } @@ -318,8 +351,9 @@ func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName // ParseConfig takes the path of a configuration file and parses it into Configuration func ParseConfig(path string) (*Configuration, error) { iniOpt := ini.LoadOptions{ - Insensitive: true, - AllowShadows: true, + Insensitive: true, + AllowShadows: true, + AllowNonUniqueSections: true, } cfg, err := ini.LoadSources(iniOpt, path) @@ -328,9 +362,7 @@ func ParseConfig(path string) (*Configuration, error) { } device := &DeviceConfig{ - PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000", - KeepAlive: 0, - MTU: 1420, + MTU: 1420, } root := cfg.Section("") @@ -348,7 +380,7 @@ func ParseConfig(path string) (*Configuration, error) { return nil, err } - err = ParsePeer(wgCfg, device) + err = ParsePeers(wgCfg, &device.Peers) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 3078d49..e4538df 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/octeep/wireproxy go 1.18 require ( + github.com/MakeNowJust/heredoc/v2 v2.0.1 github.com/akamensky/argparse v1.3.1 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/go-ini/ini v1.66.4 diff --git a/go.sum b/go.sum index 600a730..d51823f 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbt github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/MakeNowJust/heredoc/v2 v2.0.1 h1:rlCHh70XXXv7toz95ajQWOWQnN4WNLt0TdpZYIR/J6A= +github.com/MakeNowJust/heredoc/v2 v2.0.1/go.mod h1:6/2Abh5s+hc3g9nbWLe9ObDIOhaRrqsyY9MWy+4JdRM= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= github.com/Microsoft/go-winio v0.4.16-0.20201130162521-d1ffc52c7331/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= diff --git a/wireguard.go b/wireguard.go index 493db02..e3d6e3e 100644 --- a/wireguard.go +++ b/wireguard.go @@ -1,12 +1,15 @@ package wireproxy import ( + "bytes" "fmt" + "net/netip" + + "github.com/MakeNowJust/heredoc/v2" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "net/netip" ) // DeviceSetting contains the parameters for setting up a tun interface @@ -19,15 +22,33 @@ type DeviceSetting struct { // serialize the config into an IPC request and DeviceSetting func createIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { - request := fmt.Sprintf(`private_key=%s -public_key=%s -endpoint=%s -persistent_keepalive_interval=%d -preshared_key=%s -allowed_ip=0.0.0.0/0 -allowed_ip=::0/0`, conf.SelfSecretKey, conf.PeerPublicKey, conf.PeerEndpoint, conf.KeepAlive, conf.PreSharedKey) + var request bytes.Buffer - setting := &DeviceSetting{ipcRequest: request, dns: conf.DNS, deviceAddr: conf.SelfEndpoint, mtu: conf.MTU} + request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SelfSecretKey)) + + for _, peer := range conf.Peers { + request.WriteString(fmt.Sprintf(heredoc.Doc(` + public_key=%s + endpoint=%s + persistent_keepalive_interval=%d + preshared_key=%s + `), + peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey, + )) + + if len(peer.AllowedIPs) > 0 { + for _, ip := range peer.AllowedIPs { + request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) + } + } else { + request.WriteString(heredoc.Doc(` + allowed_ip=0.0.0.0/0 + allowed_ip=::0/0 + `)) + } + } + + setting := &DeviceSetting{ipcRequest: request.String(), dns: conf.DNS, deviceAddr: conf.SelfEndpoint, mtu: conf.MTU} return setting, nil } @@ -53,6 +74,9 @@ func StartWireguard(conf *DeviceConfig) (*VirtualTun, error) { return nil, err } + // Make sure an initial handshake happpens + dev.SendKeepalivesToPeersWithCurrentKeypair() + return &VirtualTun{ tnet: tnet, systemDNS: len(setting.dns) == 0,