mirror of
https://github.com/whyvl/wireproxy.git
synced 2025-04-29 19:01:42 +02:00
Merge 8ae43d6103
into 9dad356bee
This commit is contained in:
commit
0c662b2c4b
2 changed files with 199 additions and 0 deletions
39
config.go
39
config.go
|
@ -33,6 +33,12 @@ type DeviceConfig struct {
|
||||||
CheckAliveInterval int
|
CheckAliveInterval int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UDPProxyTunnelConfig struct {
|
||||||
|
BindAddress string
|
||||||
|
Target string
|
||||||
|
InactivityTimeout int
|
||||||
|
}
|
||||||
|
|
||||||
type TCPClientTunnelConfig struct {
|
type TCPClientTunnelConfig struct {
|
||||||
BindAddress *net.TCPAddr
|
BindAddress *net.TCPAddr
|
||||||
Target string
|
Target string
|
||||||
|
@ -435,6 +441,34 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseUDPProxyTunnelConfig(section *ini.Section) (RoutineSpawner, error) {
|
||||||
|
config := &UDPProxyTunnelConfig{}
|
||||||
|
|
||||||
|
bindAddress, err := parseString(section, "BindAddress")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
config.BindAddress = bindAddress
|
||||||
|
|
||||||
|
target, err := parseString(section, "Target")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
config.Target = target
|
||||||
|
|
||||||
|
inactivityTimeout := 0
|
||||||
|
if sectionKey, err := section.GetKey("InactivityTimeout"); err == nil {
|
||||||
|
timeoutVal, err := sectionKey.Int()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inactivityTimeout = timeoutVal
|
||||||
|
}
|
||||||
|
config.InactivityTimeout = inactivityTimeout
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Takes a function that parses an individual section into a config, and apply it on all
|
// Takes a function that parses an individual section into a config, and apply it on all
|
||||||
// specified sections
|
// specified sections
|
||||||
func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error {
|
func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error {
|
||||||
|
@ -519,6 +553,11 @@ func ParseConfig(path string) (*Configuration, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = parseRoutinesConfig(&routinesSpawners, cfg, "UDPProxyTunnel", parseUDPProxyTunnelConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &Configuration{
|
return &Configuration{
|
||||||
Device: device,
|
Device: device,
|
||||||
Routines: routinesSpawners,
|
Routines: routinesSpawners,
|
||||||
|
|
160
udp_proxy.go
Normal file
160
udp_proxy.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package wireproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// udpSession represents a UDP forwarding session, keyed by the local source address.
|
||||||
|
// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side).
|
||||||
|
type udpSession struct {
|
||||||
|
remoteConn net.Conn
|
||||||
|
lastActive time.Time
|
||||||
|
closeChan chan struct{}
|
||||||
|
inactivityDur time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpawnRoutine implements the RoutineSpawner interface.
|
||||||
|
// It starts listening on config.BindAddress, handling each unique source (client) address
|
||||||
|
// with its own udpSession. If InactivityTimeout > 0, sessions automatically close after inactivity
|
||||||
|
func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) {
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", conf.BindAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err)
|
||||||
|
}
|
||||||
|
log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target)
|
||||||
|
|
||||||
|
inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second
|
||||||
|
sessions := make(map[string]*udpSession)
|
||||||
|
var sessionMu sync.Mutex
|
||||||
|
|
||||||
|
// Periodically clean up expired sessions if inactivity timeout is enabled
|
||||||
|
if conf.InactivityTimeout > 0 {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
sessionMu.Lock()
|
||||||
|
for key, sess := range sessions {
|
||||||
|
if now.Sub(sess.lastActive) >= inactivityDur {
|
||||||
|
log.Printf("UDPProxyTunnel: closing inactive session for %s", key)
|
||||||
|
close(sess.closeChan)
|
||||||
|
delete(sessions, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sessionMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create or get a UDP session based on the local source address
|
||||||
|
getOrCreateSession := func(srcAddr string) (*udpSession, error) {
|
||||||
|
sessionMu.Lock()
|
||||||
|
defer sessionMu.Unlock()
|
||||||
|
|
||||||
|
// return if session already exists
|
||||||
|
if s, ok := sessions[srcAddr]; ok {
|
||||||
|
s.lastActive = time.Now()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new session
|
||||||
|
remoteConn, err := vt.Tnet.Dial("udp", conf.Target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &udpSession{
|
||||||
|
remoteConn: remoteConn,
|
||||||
|
lastActive: time.Now(),
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
inactivityDur: inactivityDur,
|
||||||
|
}
|
||||||
|
sessions[srcAddr] = s
|
||||||
|
|
||||||
|
// Spin up a goroutine to handle traffic from remote -> local
|
||||||
|
go conf.handleRemoteToLocal(listener, srcAddr, s)
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main loop to read from local client and forward to remote
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 64*1024) // typical max UDP size
|
||||||
|
for {
|
||||||
|
n, src, err := listener.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("UDPProxyTunnel: error reading from UDP: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
srcKey := src.String() // identify session by the local client's IP:port
|
||||||
|
s, err := getOrCreateSession(srcKey)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastActive = time.Now()
|
||||||
|
_, err = s.remoteConn.Write(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handles data from the remote WireGuard side back to the local client
|
||||||
|
// this function blocks until the session is closed
|
||||||
|
func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) {
|
||||||
|
defer func() {
|
||||||
|
_ = s.remoteConn.Close()
|
||||||
|
}()
|
||||||
|
buf := make([]byte, 64*1024)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
n, err := s.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
// If a timeout or temporary error, continue to see if the session is closed
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
select {
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastActive = time.Now()
|
||||||
|
|
||||||
|
dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = listener.WriteToUDP(buf[:n], dstUDPAddr)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue