diff --git a/http.go b/http.go index 1d224b9..92f1a47 100644 --- a/http.go +++ b/http.go @@ -10,6 +10,8 @@ import ( "net" "net/http" "strings" + + "github.com/sourcegraph/conc" ) const proxyAuthHeaderKey = "Proxy-Authorization" @@ -91,17 +93,19 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) { return } -func (s *HTTPServer) serve(conn net.Conn) error { +func (s *HTTPServer) serve(conn net.Conn) { var rd = bufio.NewReader(conn) req, err := http.ReadRequest(rd) if err != nil { - return fmt.Errorf("read request failed: %w", err) + log.Printf("read request failed: %s\n", err) + return } code, err := s.authenticate(req) if err != nil { _ = responseWith(req, code).Write(conn) - return err + log.Println(err) + return } var peer net.Conn @@ -112,23 +116,29 @@ func (s *HTTPServer) serve(conn net.Conn) error { peer, err = s.handle(req) default: _ = responseWith(req, http.StatusMethodNotAllowed).Write(conn) - return fmt.Errorf("unsupported protocol: %s", req.Method) + log.Printf("unsupported protocol: %s\n", req.Method) + return } if err != nil { - return fmt.Errorf("dial proxy failed: %w", err) + log.Printf("dial proxy failed: %s\n", err) + return } if peer == nil { - return fmt.Errorf("dial proxy failed: peer nil") + log.Println("dial proxy failed: peer nil") + return } - defer peer.Close() - go func() { - _, _ = io.Copy(conn, peer) + wg := conc.NewWaitGroup() + wg.Go(func() { + _, err = io.Copy(conn, peer) + }) + wg.Go(func() { + _, err = io.Copy(peer, conn) + }) + wg.Wait() + _ = peer.Close() + _ = conn.Close() }() - - _, err = io.Copy(peer, conn) - - return err } // ListenAndServe is used to create a listener and serve on it @@ -137,20 +147,16 @@ func (s *HTTPServer) ListenAndServe(network, addr string) error { if err != nil { return fmt.Errorf("listen tcp failed: %w", err) } - defer server.Close() - var conn net.Conn + defer func(server net.Listener) { + _ = server.Close() + }(server) for { - conn, err = server.Accept() + conn, err := server.Accept() if err != nil { return fmt.Errorf("accept request failed: %w", err) } go func(conn net.Conn) { - err = s.serve(conn) - if err != nil { - log.Println(err) - } - _ = conn.Close() - conn = nil + s.serve(conn) }(conn) } } diff --git a/routine.go b/routine.go index 5044812..13fa944 100644 --- a/routine.go +++ b/routine.go @@ -170,7 +170,7 @@ func (c CredentialValidator) Valid(username, password string) bool { return u&p == 1 } -// connForward copy data from `from` to `to`, then close both stream. +// connForward copy data from `from` to `to` func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) { _, err := io.Copy(to, from) if err != nil { @@ -195,15 +195,18 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { } go func() { - gr := conc.NewWaitGroup() - gr.Go(func() { + wg := conc.NewWaitGroup() + wg.Go(func() { connForward(sconn, conn) }) - gr.Go(func() { + wg.Go(func() { connForward(conn, sconn) }) - gr.Wait() + wg.Wait() _ = sconn.Close() + _ = conn.Close() + sconn = nil + conn = nil }() } @@ -230,14 +233,16 @@ func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) { } go func() { - gr := conc.NewWaitGroup() - gr.Go(func() { + wg := conc.NewWaitGroup() + wg.Go(func() { connForward(os.Stdin, sconn) }) - gr.Go(func() { + wg.Go(func() { connForward(sconn, stdout) }) - gr.Wait() + wg.Wait() + _ = sconn.Close() + sconn = nil }() } @@ -253,9 +258,8 @@ func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) { log.Fatal(err) } - var conn net.Conn for { - conn, err = server.Accept() + conn, err := server.Accept() if err != nil { log.Fatal(err) } @@ -299,6 +303,9 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { }) gr.Wait() _ = sconn.Close() + _ = conn.Close() + sconn = nil + conn = nil }() } @@ -315,9 +322,8 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { log.Fatal(err) } - var conn net.Conn for { - conn, err = server.Accept() + conn, err := server.Accept() if err != nil { log.Fatal(err) }