diff --git a/go.mod b/go.mod index 3c2e14e..ab52783 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/google/btree v1.1.2 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect golang.org/x/crypto v0.19.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.17.0 // indirect diff --git a/go.sum b/go.sum index bc70e9b..949c441 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/things-go/go-socks5 v0.0.5 h1:qvKaGcBkfDrUL33SchHN93srAmYGzb4CxSM2DPYufe8= diff --git a/routine.go b/routine.go index d7af0ce..5044812 100644 --- a/routine.go +++ b/routine.go @@ -11,6 +11,7 @@ import ( "os" "strconv" + "github.com/sourcegraph/conc" "github.com/things-go/go-socks5" "github.com/things-go/go-socks5/bufferpool" @@ -175,8 +176,6 @@ func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) { if err != nil { errorLogger.Printf("Cannot forward traffic: %s\n", err.Error()) } - _ = from.Close() - _ = to.Close() } // tcpClientForward starts a new connection via wireguard and forward traffic from `conn` @@ -195,8 +194,17 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - go connForward(sconn, conn) - go connForward(conn, sconn) + go func() { + gr := conc.NewWaitGroup() + gr.Go(func() { + connForward(sconn, conn) + }) + gr.Go(func() { + connForward(conn, sconn) + }) + gr.Wait() + _ = sconn.Close() + }() } // STDIOTcpForward starts a new connection via wireguard and forward traffic from `conn` @@ -221,8 +229,16 @@ func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) { return } - go connForward(os.Stdin, sconn) - go connForward(sconn, stdout) + go func() { + gr := conc.NewWaitGroup() + gr.Go(func() { + connForward(os.Stdin, sconn) + }) + gr.Go(func() { + connForward(sconn, stdout) + }) + gr.Wait() + }() } // SpawnRoutine spawns a local TCP server which acts as a proxy to the specified target @@ -273,8 +289,17 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - go connForward(sconn, conn) - go connForward(conn, sconn) + go func() { + gr := conc.NewWaitGroup() + gr.Go(func() { + connForward(sconn, conn) + }) + gr.Go(func() { + connForward(conn, sconn) + }) + gr.Wait() + _ = sconn.Close() + }() } // SpawnRoutine spawns a TCP server on wireguard which acts as a proxy to the specified target