diff --git a/config.go b/config.go index 1430a70..a52431c 100644 --- a/config.go +++ b/config.go @@ -12,16 +12,21 @@ import ( "net/netip" ) +type PeerConfig struct { + PublicKey string + PreSharedKey string + Endpoint string + KeepAlive int + AllowedIPs []netip.Prefix +} + // DeviceConfig contains the information to initiate a wireguard connection type DeviceConfig struct { - SelfSecretKey string - SelfEndpoint []netip.Addr - PeerPublicKey string - PeerEndpoint string - DNS []netip.Addr - KeepAlive int - PreSharedKey string - MTU int + SecretKey string + Endpoint []netip.Addr + Peers []PeerConfig + DNS []netip.Addr + MTU int } type TCPClientTunnelConfig struct { @@ -144,6 +149,24 @@ func parseCIDRNetIP(section *ini.Section, keyName string) ([]netip.Addr, error) return ips, nil } +func parseAllowedIPs(section *ini.Section) ([]netip.Prefix, error) { + key := section.Key("AllowedIPs") + if key == nil { + return []netip.Prefix{}, nil + } + + var ips []netip.Prefix + for _, str := range key.StringsWithShadows(",") { + prefix, err := netip.ParsePrefix(str) + if err != nil { + return nil, err + } + + ips = append(ips, prefix) + } + return ips, nil +} + func resolveIP(ip string) (*net.IPAddr, error) { return net.ResolveIPAddr("ip", ip) } @@ -174,13 +197,13 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error { return err } - device.SelfEndpoint = address + device.Endpoint = address privKey, err := parseBase64KeyToHex(section, "PrivateKey") if err != nil { return err } - device.SelfSecretKey = privKey + device.SecretKey = privKey dns, err := parseNetIP(section, "DNS") if err != nil { @@ -199,46 +222,58 @@ func ParseInterface(cfg *ini.File, device *DeviceConfig) error { return nil } -// ParsePeer parses the [Peer] section and extract the information into `device` -func ParsePeer(cfg *ini.File, device *DeviceConfig) error { +// ParsePeer parses the [Peer] section and extract the information into `peers` +func ParsePeers(cfg *ini.File, peers *[]PeerConfig) error { sections, err := cfg.SectionsByName("Peer") - if len(sections) != 1 || err != nil { - return errors.New("one and only one [Peer] is expected") + if len(sections) < 1 || err != nil { + return errors.New("at least one [Peer] is expected") } - section := sections[0] - decoded, err := parseBase64KeyToHex(section, "PublicKey") - if err != nil { - return err - } - device.PeerPublicKey = decoded + for _, section := range sections { + peer := PeerConfig{ + PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000", + KeepAlive: 0, + } - if sectionKey, err := section.GetKey("PreSharedKey"); err == nil { - value, err := encodeBase64ToHex(sectionKey.String()) + decoded, err := parseBase64KeyToHex(section, "PublicKey") if err != nil { return err } - device.PreSharedKey = value - } + peer.PublicKey = decoded - decoded, err = parseString(section, "Endpoint") - if err != nil { - return err - } - decoded, err = resolveIPPAndPort(decoded) - if err != nil { - return err - } - device.PeerEndpoint = decoded + if sectionKey, err := section.GetKey("PreSharedKey"); err == nil { + value, err := encodeBase64ToHex(sectionKey.String()) + if err != nil { + return err + } + peer.PreSharedKey = value + } - if sectionKey, err := section.GetKey("PersistentKeepalive"); err == nil { - value, err := sectionKey.Int() + decoded, err = parseString(section, "Endpoint") if err != nil { return err } - device.KeepAlive = value - } + decoded, err = resolveIPPAndPort(decoded) + if err != nil { + return err + } + peer.Endpoint = decoded + if sectionKey, err := section.GetKey("PersistentKeepalive"); err == nil { + value, err := sectionKey.Int() + if err != nil { + return err + } + peer.KeepAlive = value + } + + peer.AllowedIPs, err = parseAllowedIPs(section) + if err != nil { + return err + } + + *peers = append(*peers, peer) + } return nil } @@ -318,8 +353,9 @@ func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName // ParseConfig takes the path of a configuration file and parses it into Configuration func ParseConfig(path string) (*Configuration, error) { iniOpt := ini.LoadOptions{ - Insensitive: true, - AllowShadows: true, + Insensitive: true, + AllowShadows: true, + AllowNonUniqueSections: true, } cfg, err := ini.LoadSources(iniOpt, path) @@ -328,9 +364,7 @@ func ParseConfig(path string) (*Configuration, error) { } device := &DeviceConfig{ - PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000", - KeepAlive: 0, - MTU: 1420, + MTU: 1420, } root := cfg.Section("") @@ -348,7 +382,7 @@ func ParseConfig(path string) (*Configuration, error) { return nil, err } - err = ParsePeer(wgCfg, device) + err = ParsePeers(wgCfg, &device.Peers) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index ff57620..4a98f29 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/octeep/wireproxy go 1.18 require ( + github.com/MakeNowJust/heredoc/v2 v2.0.1 github.com/akamensky/argparse v1.3.1 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/go-ini/ini v1.66.4 diff --git a/go.sum b/go.sum index 2779173..498545f 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbt github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/MakeNowJust/heredoc/v2 v2.0.1 h1:rlCHh70XXXv7toz95ajQWOWQnN4WNLt0TdpZYIR/J6A= +github.com/MakeNowJust/heredoc/v2 v2.0.1/go.mod h1:6/2Abh5s+hc3g9nbWLe9ObDIOhaRrqsyY9MWy+4JdRM= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= github.com/Microsoft/go-winio v0.4.16-0.20201130162521-d1ffc52c7331/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= diff --git a/wireguard.go b/wireguard.go index 493db02..3f497e5 100644 --- a/wireguard.go +++ b/wireguard.go @@ -1,12 +1,15 @@ package wireproxy import ( + "bytes" "fmt" + "net/netip" + + "github.com/MakeNowJust/heredoc/v2" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "net/netip" ) // DeviceSetting contains the parameters for setting up a tun interface @@ -19,15 +22,33 @@ type DeviceSetting struct { // serialize the config into an IPC request and DeviceSetting func createIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { - request := fmt.Sprintf(`private_key=%s -public_key=%s -endpoint=%s -persistent_keepalive_interval=%d -preshared_key=%s -allowed_ip=0.0.0.0/0 -allowed_ip=::0/0`, conf.SelfSecretKey, conf.PeerPublicKey, conf.PeerEndpoint, conf.KeepAlive, conf.PreSharedKey) + var request bytes.Buffer - setting := &DeviceSetting{ipcRequest: request, dns: conf.DNS, deviceAddr: conf.SelfEndpoint, mtu: conf.MTU} + request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) + + for _, peer := range conf.Peers { + request.WriteString(fmt.Sprintf(heredoc.Doc(` + public_key=%s + endpoint=%s + persistent_keepalive_interval=%d + preshared_key=%s + `), + peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey, + )) + + if len(peer.AllowedIPs) > 0 { + for _, ip := range peer.AllowedIPs { + request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip.String())) + } + } else { + request.WriteString(heredoc.Doc(` + allowed_ip=0.0.0.0/0 + allowed_ip=::0/0 + `)) + } + } + + setting := &DeviceSetting{ipcRequest: request.String(), dns: conf.DNS, deviceAddr: conf.Endpoint, mtu: conf.MTU} return setting, nil }