add proxy stats endpoint

This commit is contained in:
VastBlast 2024-12-25 00:54:22 -05:00
parent 3e6e5a61f0
commit df77f59a2f
2 changed files with 153 additions and 37 deletions

51
http.go
View file

@ -50,6 +50,7 @@ func (s *HTTPServer) authenticate(req *http.Request) (int, error) {
return http.StatusUnauthorized, fmt.Errorf("username and password not matching")
}
// handleConn sets up tunneling for CONNECT requests.
func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn, err error) {
addr := req.Host
if !strings.Contains(addr, ":") {
@ -59,18 +60,19 @@ func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn
peer, err = s.dial("tcp", addr)
if err != nil {
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
return nil, fmt.Errorf("tun tcp dial failed: %w", err)
}
_, err = conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
if err != nil {
_ = peer.Close()
peer = nil
return nil, err
}
return
return peer, nil
}
// handle handles standard HTTP methods.
func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) {
addr := req.Host
if !strings.Contains(addr, ":") {
@ -80,35 +82,37 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) {
peer, err = s.dial("tcp", addr)
if err != nil {
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
return nil, fmt.Errorf("tun tcp dial failed: %w", err)
}
err = req.Write(peer)
if err != nil {
_ = peer.Close()
peer = nil
return peer, fmt.Errorf("conn write failed: %w", err)
return nil, fmt.Errorf("conn write failed: %w", err)
}
return
return peer, nil
}
// serve handles one connection from the listener.
func (s *HTTPServer) serve(conn net.Conn) {
var rd = bufio.NewReader(conn)
req, err := http.ReadRequest(rd)
if err != nil {
log.Printf("read request failed: %s\n", err)
conn.Close() // ensure StatsConn closes
return
}
code, err := s.authenticate(req)
if err != nil {
code, authErr := s.authenticate(req)
if authErr != nil {
resp := responseWith(req, code)
if code == http.StatusProxyAuthRequired {
resp.Header.Set("Proxy-Authenticate", "Basic realm=\"Proxy\"")
}
_ = resp.Write(conn)
log.Println(err)
log.Println(authErr)
conn.Close() // ensure StatsConn closes
return
}
@ -121,46 +125,41 @@ func (s *HTTPServer) serve(conn net.Conn) {
default:
_ = responseWith(req, http.StatusMethodNotAllowed).Write(conn)
log.Printf("unsupported protocol: %s\n", req.Method)
conn.Close() // ensure StatsConn closes
return
}
if err != nil {
log.Printf("dial proxy failed: %s\n", err)
conn.Close() // ensure StatsConn closes
return
}
if peer == nil {
log.Println("dial proxy failed: peer nil")
conn.Close() // ensure StatsConn closes
return
}
go func() {
wg := conc.NewWaitGroup()
wg.Go(func() {
_, err = io.Copy(conn, peer)
_ = conn.Close()
_, _ = io.Copy(conn, peer)
conn.Close()
})
wg.Go(func() {
_, err = io.Copy(peer, conn)
_, _ = io.Copy(peer, conn)
_ = peer.Close()
})
wg.Wait()
}()
}
// ListenAndServe is used to create a listener and serve on it
func (s *HTTPServer) ListenAndServe(network, addr string) error {
server, err := net.Listen(network, addr)
if err != nil {
return fmt.Errorf("listen tcp failed: %w", err)
}
defer func(server net.Listener) {
_ = server.Close()
}(server)
// Serve runs an accept loop on the given listener.
func (s *HTTPServer) Serve(listener net.Listener) error {
for {
conn, err := server.Accept()
conn, err := listener.Accept()
if err != nil {
return fmt.Errorf("accept request failed: %w", err)
}
go func(conn net.Conn) {
s.serve(conn)
}(conn)
go s.serve(conn)
}
}

View file

@ -21,6 +21,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/sourcegraph/conc"
@ -35,6 +36,59 @@ import (
// errorLogger is the logger to print error message
var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags)
type ProxyStats struct {
ProxyType string `json:"type"`
BindAddress string `json:"bind_address"`
ActiveConnections int `json:"active_connections"`
LastConnectionTime int64 `json:"last_connection_time"`
TotalConnections int `json:"total_connections"`
mu sync.Mutex
}
func (ps *ProxyStats) IncConnection() {
ps.mu.Lock()
defer ps.mu.Unlock()
ps.ActiveConnections++
ps.TotalConnections++
ps.LastConnectionTime = time.Now().Unix()
}
func (ps *ProxyStats) DecConnection() {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.ActiveConnections > 0 {
ps.ActiveConnections--
}
}
type StatsListener struct {
net.Listener
stats *ProxyStats
}
func (sl *StatsListener) Accept() (net.Conn, error) {
c, err := sl.Listener.Accept()
if err == nil {
sl.stats.IncConnection()
c = &StatsConn{
Conn: c,
stats: sl.stats,
}
}
return c, err
}
type StatsConn struct {
net.Conn
stats *ProxyStats
}
func (sc *StatsConn) Close() error {
sc.stats.DecConnection()
return sc.Conn.Close()
}
// CredentialValidator stores the authentication data of a socks5 proxy
type CredentialValidator struct {
username string
@ -49,6 +103,8 @@ type VirtualTun struct {
Conf *DeviceConfig
// PingRecord stores the last time an IP was pinged
PingRecord map[string]uint64
mu sync.Mutex
ProxyList []*ProxyStats
}
// RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed
@ -63,7 +119,7 @@ type addressPort struct {
// LookupAddr lookups a hostname.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) {
func (d *VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) {
if d.SystemDNS {
return net.DefaultResolver.LookupHost(ctx, name)
}
@ -72,7 +128,7 @@ func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, erro
// ResolveAddrWithContext resolves a hostname and returns an AddrPort.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) {
func (d *VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) {
addrs, err := d.LookupAddr(ctx, name)
if err != nil {
return nil, err
@ -104,7 +160,7 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n
// Resolve resolves a hostname and returns an IP.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
func (d *VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := d.ResolveAddrWithContext(ctx, name)
if err != nil {
return nil, nil, err
@ -127,7 +183,7 @@ func parseAddressPort(endpoint string) (*addressPort, error) {
return &addressPort{address: name, port: uint16(port)}, nil
}
func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) {
func (d *VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) {
addr, err := d.ResolveAddrWithContext(context.Background(), endpoint.address)
if err != nil {
return nil, err
@ -139,6 +195,12 @@ func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, e
// SpawnRoutine spawns a socks5 server.
func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) {
stats := &ProxyStats{
ProxyType: "socks5",
BindAddress: config.BindAddress,
}
vt.RegisterProxyStats(stats)
var authMethods []socks5.Authenticator
if username := config.Username; username != "" {
authMethods = append(authMethods, socks5.UserPassAuthenticator{
@ -157,13 +219,27 @@ func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) {
server := socks5.NewServer(options...)
if err := server.ListenAndServe("tcp", config.BindAddress); err != nil {
ln, err := net.Listen("tcp", config.BindAddress)
if err != nil {
log.Fatal(err)
}
ln = &StatsListener{Listener: ln, stats: stats}
go func() {
if err := server.Serve(ln); err != nil {
log.Fatal(err)
}
}()
}
// SpawnRoutine spawns a http server.
func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) {
stats := &ProxyStats{
ProxyType: "http",
BindAddress: config.BindAddress,
}
vt.RegisterProxyStats(stats)
server := &HTTPServer{
config: config,
dial: vt.Tnet.Dial,
@ -173,9 +249,17 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) {
server.authRequired = true
}
if err := server.ListenAndServe("tcp", config.BindAddress); err != nil {
ln, err := net.Listen("tcp", config.BindAddress)
if err != nil {
log.Fatal(err)
}
ln = &StatsListener{Listener: ln, stats: stats}
go func() {
if err := server.Serve(ln); err != nil {
log.Fatal(err)
}
}()
}
// Valid checks the authentication data in CredentialValidator and compare them
@ -198,7 +282,7 @@ func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) {
func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error())
return
}
@ -297,7 +381,7 @@ func (conf *STDIOTunnelConfig) SpawnRoutine(vt *VirtualTun) {
func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", raddr.address, err.Error())
return
}
@ -347,7 +431,8 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
}
}
func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// ServeHTTP is used for health/metrics requests.
func (d *VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Health metric request: %s\n", r.URL.Path)
switch path.Clean(r.URL.Path) {
case "/readyz":
@ -396,12 +481,36 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(buf.Bytes())
case "/stats":
// Return statistics about all running proxies
d.mu.Lock()
snapshot := make([]ProxyStats, len(d.ProxyList))
for i, ps := range d.ProxyList {
ps.mu.Lock()
snapshot[i] = ProxyStats{
ProxyType: ps.ProxyType,
BindAddress: ps.BindAddress,
ActiveConnections: ps.ActiveConnections,
LastConnectionTime: ps.LastConnectionTime,
TotalConnections: ps.TotalConnections,
}
ps.mu.Unlock()
}
d.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(snapshot); err != nil {
errorLogger.Printf("Failed to encode /stats: %s", err)
}
default:
w.WriteHeader(http.StatusNotFound)
}
}
func (d VirtualTun) pingIPs() {
// pingIPs pings the IP addresses configured in CheckAlive
func (d *VirtualTun) pingIPs() {
for _, addr := range d.Conf.CheckAlive {
socket, err := d.Tnet.Dial("ping", addr.String())
if err != nil {
@ -482,7 +591,8 @@ func (d VirtualTun) pingIPs() {
}
}
func (d VirtualTun) StartPingIPs() {
// StartPingIPs starts a goroutine that periodically pings the IP addresses in CheckAlive
func (d *VirtualTun) StartPingIPs() {
for _, addr := range d.Conf.CheckAlive {
d.PingRecord[addr.String()] = 0
}
@ -494,3 +604,10 @@ func (d VirtualTun) StartPingIPs() {
}
}()
}
// RegisterProxyStats is used to store the newly created proxy stats object
func (d *VirtualTun) RegisterProxyStats(ps *ProxyStats) {
d.mu.Lock()
defer d.mu.Unlock()
d.ProxyList = append(d.ProxyList, ps)
}