wireproxy/udp_proxy.go
2025-01-01 01:04:19 -05:00

160 lines
4.4 KiB
Go

package wireproxy
import (
"fmt"
"log"
"net"
"sync"
"time"
)
// udpSession represents a UDP forwarding session, keyed by the local client's source address.
// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side).
type udpSession struct {
	remoteConn net.Conn
	// mu guards lastActive: it is written by both forwarding directions and read
	// by the inactivity sweeper, each running in its own goroutine.
	mu         sync.Mutex
	lastActive time.Time
	// closeChan is closed only by the inactivity sweeper (while the session is
	// still registered) to tell the remote->local goroutine to stop.
	closeChan     chan struct{}
	inactivityDur time.Duration
}

// touch records activity on the session at the current time.
func (s *udpSession) touch() {
	s.mu.Lock()
	s.lastActive = time.Now()
	s.mu.Unlock()
}

// idleFor reports how long the session has been inactive as of now.
func (s *udpSession) idleFor(now time.Time) time.Duration {
	s.mu.Lock()
	defer s.mu.Unlock()
	return now.Sub(s.lastActive)
}

// SpawnRoutine implements the RoutineSpawner interface.
// It listens on conf.BindAddress and forwards datagrams to conf.Target through the
// virtual WireGuard network. Each distinct client source address gets its own
// udpSession with its own dialed remote connection. If InactivityTimeout > 0,
// sessions idle for at least that many seconds are closed by a periodic sweeper.
// A session whose remote side fails is unregistered, so the client's next
// datagram dials a fresh connection.
func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) {
	addr, err := net.ResolveUDPAddr("udp", conf.BindAddress)
	if err != nil {
		log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err)
	}
	listener, err := net.ListenUDP("udp", addr)
	if err != nil {
		log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err)
	}
	log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target)

	inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second
	sessions := make(map[string]*udpSession)
	var sessionMu sync.Mutex

	// Periodically close sessions that have been idle for too long.
	if conf.InactivityTimeout > 0 {
		go func() {
			ticker := time.NewTicker(10 * time.Second)
			defer ticker.Stop()
			for range ticker.C {
				now := time.Now()
				sessionMu.Lock()
				for key, sess := range sessions {
					if sess.idleFor(now) >= inactivityDur {
						log.Printf("UDPProxyTunnel: closing inactive session for %s", key)
						close(sess.closeChan)
						delete(sessions, key)
					}
				}
				sessionMu.Unlock()
			}
		}()
	}

	// getOrCreateSession returns the session for srcAddr, refreshing its activity
	// time. If none exists, it dials conf.Target and starts the remote->local goroutine.
	getOrCreateSession := func(srcAddr string) (*udpSession, error) {
		sessionMu.Lock()
		defer sessionMu.Unlock()
		if s, ok := sessions[srcAddr]; ok {
			s.touch()
			return s, nil
		}
		remoteConn, err := vt.Tnet.Dial("udp", conf.Target)
		if err != nil {
			return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err)
		}
		s := &udpSession{
			remoteConn:    remoteConn,
			lastActive:    time.Now(),
			closeChan:     make(chan struct{}),
			inactivityDur: inactivityDur,
		}
		sessions[srcAddr] = s
		go func() {
			conf.handleRemoteToLocal(listener, srcAddr, s)
			// The handler has closed remoteConn. Unregister the session unless the
			// sweeper already removed it (or a replacement now owns the key), so the
			// next datagram from this client dials a fresh connection instead of
			// writing to a closed one forever.
			sessionMu.Lock()
			if sessions[srcAddr] == s {
				delete(sessions, srcAddr)
			}
			sessionMu.Unlock()
		}()
		return s, nil
	}

	// Main loop: read from local clients and forward to their remote sessions.
	go func() {
		buf := make([]byte, 64*1024) // large enough for any UDP payload
		for {
			n, src, err := listener.ReadFromUDP(buf)
			if err != nil {
				errorLogger.Printf("UDPProxyTunnel: error reading from UDP: %v", err)
				continue
			}
			srcKey := src.String() // sessions are keyed by the client's IP:port
			s, err := getOrCreateSession(srcKey)
			if err != nil {
				errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err)
				continue
			}
			// getOrCreateSession has already refreshed lastActive under the session lock.
			if _, err := s.remoteConn.Write(buf[:n]); err != nil {
				errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err)
			}
		}
	}()
}

// handleRemoteToLocal copies datagrams from the session's remote (WireGuard-side)
// connection back to the local client at srcAddr via listener. It blocks until
// s.closeChan is closed or a non-timeout error occurs. It always closes
// s.remoteConn before returning.
func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) {
	defer func() {
		_ = s.remoteConn.Close()
	}()
	// The client address is fixed for the lifetime of the session, so resolve it
	// once instead of on every packet.
	dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr)
	if err != nil {
		errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err)
		return
	}
	buf := make([]byte, 64*1024)
	for {
		select {
		case <-s.closeChan:
			return
		default:
		}
		// A bounded read lets the loop periodically re-check closeChan.
		_ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second))
		n, err := s.remoteConn.Read(buf)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue // the loop head re-checks closeChan
			}
			errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err)
			return
		}
		s.touch()
		if _, err := listener.WriteToUDP(buf[:n], dstUDPAddr); err != nil {
			errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err)
			return
		}
	}
}