diff --git a/http.go b/http.go index ebaa822..b5b834b 100644 --- a/http.go +++ b/http.go @@ -50,6 +50,7 @@ func (s *HTTPServer) authenticate(req *http.Request) (int, error) { return http.StatusUnauthorized, fmt.Errorf("username and password not matching") } +// handleConn sets up tunneling for CONNECT requests. func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn, err error) { addr := req.Host if !strings.Contains(addr, ":") { @@ -59,18 +60,19 @@ func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn peer, err = s.dial("tcp", addr) if err != nil { - return peer, fmt.Errorf("tun tcp dial failed: %w", err) + return nil, fmt.Errorf("tun tcp dial failed: %w", err) } _, err = conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) if err != nil { _ = peer.Close() - peer = nil + return nil, err } - return + return peer, nil } +// handle handles standard HTTP methods. func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) { addr := req.Host if !strings.Contains(addr, ":") { @@ -80,35 +82,37 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) { peer, err = s.dial("tcp", addr) if err != nil { - return peer, fmt.Errorf("tun tcp dial failed: %w", err) + return nil, fmt.Errorf("tun tcp dial failed: %w", err) } err = req.Write(peer) if err != nil { _ = peer.Close() - peer = nil - return peer, fmt.Errorf("conn write failed: %w", err) + return nil, fmt.Errorf("conn write failed: %w", err) } - return + return peer, nil } +// serve handles one connection from the listener. func (s *HTTPServer) serve(conn net.Conn) { var rd = bufio.NewReader(conn) req, err := http.ReadRequest(rd) if err != nil { log.Printf("read request failed: %s\n", err) + conn.Close() // ensure StatsConn closes return } - code, err := s.authenticate(req) - if err != nil { + code, authErr := s.authenticate(req) + if authErr != nil { resp := responseWith(req, code) if code == http.StatusProxyAuthRequired { resp.Header.Set("Proxy-Authenticate", "Basic realm=\"Proxy\"") } _ = resp.Write(conn) - log.Println(err) + log.Println(authErr) + conn.Close() // ensure StatsConn closes return } @@ -121,46 +125,41 @@ func (s *HTTPServer) serve(conn net.Conn) { default: _ = responseWith(req, http.StatusMethodNotAllowed).Write(conn) log.Printf("unsupported protocol: %s\n", req.Method) + conn.Close() // ensure StatsConn closes return } + if err != nil { log.Printf("dial proxy failed: %s\n", err) + conn.Close() // ensure StatsConn closes return } if peer == nil { log.Println("dial proxy failed: peer nil") + conn.Close() // ensure StatsConn closes return } go func() { wg := conc.NewWaitGroup() wg.Go(func() { - _, err = io.Copy(conn, peer) - _ = conn.Close() + _, _ = io.Copy(conn, peer) + conn.Close() }) wg.Go(func() { - _, err = io.Copy(peer, conn) + _, _ = io.Copy(peer, conn) _ = peer.Close() }) wg.Wait() }() } -// ListenAndServe is used to create a listener and serve on it -func (s *HTTPServer) ListenAndServe(network, addr string) error { - server, err := net.Listen(network, addr) - if err != nil { - return fmt.Errorf("listen tcp failed: %w", err) - } - defer func(server net.Listener) { - _ = server.Close() - }(server) +// Serve runs an accept loop on the given listener. +func (s *HTTPServer) Serve(listener net.Listener) error { for { - conn, err := server.Accept() + conn, err := listener.Accept() if err != nil { return fmt.Errorf("accept request failed: %w", err) } - go func(conn net.Conn) { - s.serve(conn) - }(conn) + go s.serve(conn) } } diff --git a/routine.go b/routine.go index 465e6b1..a8da299 100644 --- a/routine.go +++ b/routine.go @@ -21,6 +21,7 @@ import ( "path" "strconv" "strings" + "sync" "time" "github.com/sourcegraph/conc" @@ -35,6 +36,59 @@ import ( // errorLogger is the logger to print error message var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags) +type ProxyStats struct { + ProxyType string `json:"type"` + BindAddress string `json:"bind_address"` + ActiveConnections int `json:"active_connections"` + LastConnectionTime int64 `json:"last_connection_time"` + TotalConnections int `json:"total_connections"` + + mu sync.Mutex +} + +func (ps *ProxyStats) IncConnection() { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.ActiveConnections++ + ps.TotalConnections++ + ps.LastConnectionTime = time.Now().Unix() +} + +func (ps *ProxyStats) DecConnection() { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.ActiveConnections > 0 { + ps.ActiveConnections-- + } +} + +type StatsListener struct { + net.Listener + stats *ProxyStats +} + +func (sl *StatsListener) Accept() (net.Conn, error) { + c, err := sl.Listener.Accept() + if err == nil { + sl.stats.IncConnection() + c = &StatsConn{ + Conn: c, + stats: sl.stats, + } + } + return c, err +} + +type StatsConn struct { + net.Conn + stats *ProxyStats +} + +func (sc *StatsConn) Close() error { + sc.stats.DecConnection() + return sc.Conn.Close() +} + // CredentialValidator stores the authentication data of a socks5 proxy type CredentialValidator struct { username string @@ -49,6 +103,8 @@ type VirtualTun struct { Conf *DeviceConfig // PingRecord stores the last time an IP was pinged PingRecord map[string]uint64 + mu sync.Mutex + ProxyList []*ProxyStats } // RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed @@ -63,7 +119,7 @@ type addressPort struct { // LookupAddr lookups a hostname. // DNS traffic may or may not be routed depending on VirtualTun's setting -func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) { +func (d *VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) { if d.SystemDNS { return net.DefaultResolver.LookupHost(ctx, name) } @@ -72,7 +128,7 @@ func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, erro // ResolveAddrWithContext resolves a hostname and returns an AddrPort. // DNS traffic may or may not be routed depending on VirtualTun's setting -func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) { +func (d *VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) { addrs, err := d.LookupAddr(ctx, name) if err != nil { return nil, err @@ -104,7 +160,7 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n // Resolve resolves a hostname and returns an IP. // DNS traffic may or may not be routed depending on VirtualTun's setting -func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { +func (d *VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { addr, err := d.ResolveAddrWithContext(ctx, name) if err != nil { return nil, nil, err @@ -127,7 +183,7 @@ func parseAddressPort(endpoint string) (*addressPort, error) { return &addressPort{address: name, port: uint16(port)}, nil } -func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) { +func (d *VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) { addr, err := d.ResolveAddrWithContext(context.Background(), endpoint.address) if err != nil { return nil, err @@ -139,6 +195,12 @@ func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, e // SpawnRoutine spawns a socks5 server. func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) { + stats := &ProxyStats{ + ProxyType: "socks5", + BindAddress: config.BindAddress, + } + vt.RegisterProxyStats(stats) + var authMethods []socks5.Authenticator if username := config.Username; username != "" { authMethods = append(authMethods, socks5.UserPassAuthenticator{ @@ -157,13 +219,27 @@ func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) { server := socks5.NewServer(options...) - if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { + ln, err := net.Listen("tcp", config.BindAddress) + if err != nil { log.Fatal(err) } + ln = &StatsListener{Listener: ln, stats: stats} + + go func() { + if err := server.Serve(ln); err != nil { + log.Fatal(err) + } + }() } // SpawnRoutine spawns a http server. func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { + stats := &ProxyStats{ + ProxyType: "http", + BindAddress: config.BindAddress, + } + vt.RegisterProxyStats(stats) + server := &HTTPServer{ config: config, dial: vt.Tnet.Dial, @@ -173,9 +249,17 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { server.authRequired = true } - if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { + ln, err := net.Listen("tcp", config.BindAddress) + if err != nil { log.Fatal(err) } + ln = &StatsListener{Listener: ln, stats: stats} + + go func() { + if err := server.Serve(ln); err != nil { + log.Fatal(err) + } + }() } // Valid checks the authentication data in CredentialValidator and compare them @@ -198,7 +282,7 @@ func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) { func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { target, err := vt.resolveToAddrPort(raddr) if err != nil { - errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error()) return } @@ -297,7 +381,7 @@ func (conf *STDIOTunnelConfig) SpawnRoutine(vt *VirtualTun) { func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { target, err := vt.resolveToAddrPort(raddr) if err != nil { - errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error()) return } @@ -347,7 +431,8 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { } } -func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { +// ServeHTTP is used for health/metrics requests. +func (d *VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("Health metric request: %s\n", r.URL.Path) switch path.Clean(r.URL.Path) { case "/readyz": @@ -396,12 +481,36 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write(buf.Bytes()) + + case "/stats": + // Return statistics about all running proxies + d.mu.Lock() + snapshot := make([]ProxyStats, len(d.ProxyList)) + for i, ps := range d.ProxyList { + ps.mu.Lock() + snapshot[i] = ProxyStats{ + ProxyType: ps.ProxyType, + BindAddress: ps.BindAddress, + ActiveConnections: ps.ActiveConnections, + LastConnectionTime: ps.LastConnectionTime, + TotalConnections: ps.TotalConnections, + } + ps.mu.Unlock() + } + d.mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(snapshot); err != nil { + errorLogger.Printf("Failed to encode /stats: %s", err) + } + default: w.WriteHeader(http.StatusNotFound) } } -func (d VirtualTun) pingIPs() { +// pingIPs pings the IP addresses configured in CheckAlive +func (d *VirtualTun) pingIPs() { for _, addr := range d.Conf.CheckAlive { socket, err := d.Tnet.Dial("ping", addr.String()) if err != nil { @@ -482,7 +591,8 @@ func (d VirtualTun) pingIPs() { } } -func (d VirtualTun) StartPingIPs() { +// StartPingIPs starts a goroutine that periodically pings the IP addresses in CheckAlive +func (d *VirtualTun) StartPingIPs() { for _, addr := range d.Conf.CheckAlive { d.PingRecord[addr.String()] = 0 } @@ -494,3 +604,10 @@ func (d VirtualTun) StartPingIPs() { } }() } + +// RegisterProxyStats is used to store the newly created proxy stats object +func (d *VirtualTun) RegisterProxyStats(ps *ProxyStats) { + d.mu.Lock() + defer d.mu.Unlock() + d.ProxyList = append(d.ProxyList, ps) +}