From a21bd62350e099686c3840dbd4b172555de0b0d8 Mon Sep 17 00:00:00 2001 From: Dmitry Pankov Date: Wed, 21 Aug 2024 23:39:12 +0300 Subject: [PATCH 1/2] Add support for HTTPS --- README.md | 4 ++++ config.go | 8 ++++++++ http.go | 17 ++++++++++++++++- routine.go | 4 ++++ 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 06da167..b4e392f 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,10 @@ BindAddress = 127.0.0.1:25345 #Username = ... # Avoid using spaces in the password field #Password = ... + +# Specifying certificate and key enables HTTPS +#CertFile = ... +#KeyFile = ... ``` Alternatively, if you already have a wireguard config, you can import it in the diff --git a/config.go b/config.go index b1aba15..811d36b 100644 --- a/config.go +++ b/config.go @@ -57,6 +57,8 @@ type HTTPConfig struct { BindAddress string Username string Password string + CertFile string + KeyFile string } type Configuration struct { @@ -431,6 +433,12 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) { password, _ := parseString(section, "Password") config.Password = password + certFile, _ := parseString(section, "CertFile") + config.CertFile = certFile + + keyFile, _ := parseString(section, "KeyFile") + config.KeyFile = keyFile + return config, nil } diff --git a/http.go b/http.go index 9fa7932..71e5668 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package wireproxy import ( "bufio" "bytes" + "crypto/tls" "encoding/base64" "fmt" "io" @@ -23,6 +24,7 @@ type HTTPServer struct { dial func(network, address string) (net.Conn, error) authRequired bool + tlsRequired bool } func (s *HTTPServer) authenticate(req *http.Request) (int, error) { @@ -141,9 +143,22 @@ func (s *HTTPServer) serve(conn net.Conn) { }() } +func (s *HTTPServer) listen(network, addr string) (net.Listener, error) { + if s.tlsRequired { + cert, err := tls.LoadX509KeyPair(s.config.CertFile, s.config.KeyFile) + if err != nil { + return nil, err + } + + return tls.Listen(network, addr, &tls.Config{Certificates: []tls.Certificate{cert}}) + } + + return net.Listen(network, addr) +} + // ListenAndServe is used to create a listener and serve on it func (s *HTTPServer) ListenAndServe(network, addr string) error { - server, err := net.Listen(network, addr) + server, err := s.listen(network, addr) if err != nil { return fmt.Errorf("listen tcp failed: %w", err) } diff --git a/routine.go b/routine.go index 465e6b1..eba9fde 100644 --- a/routine.go +++ b/routine.go @@ -173,6 +173,10 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { server.authRequired = true } + if config.CertFile != "" && config.KeyFile != "" { + server.tlsRequired = true + } + if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { log.Fatal(err) } From 6bd3c7443445a6951a79e7534db357aac743e395 Mon Sep 17 00:00:00 2001 From: Dmitry Pankov Date: Thu, 22 Aug 2024 00:56:46 +0300 Subject: [PATCH 2/2] Replace TCPAddrFromAddrPort with built-in one --- net.go | 16 ---------------- routine.go | 6 +++--- 2 files changed, 3 insertions(+), 19 deletions(-) delete mode 100644 net.go diff --git a/net.go b/net.go deleted file mode 100644 index 3c89f39..0000000 --- a/net.go +++ /dev/null @@ -1,16 +0,0 @@ -// will delete when upgrading to go 1.18 - -package wireproxy - -import ( - "net" - "net/netip" -) - -func TCPAddrFromAddrPort(addr netip.AddrPort) *net.TCPAddr { - return &net.TCPAddr{ - IP: addr.Addr().AsSlice(), - Zone: addr.Addr().Zone(), - Port: int(addr.Port()), - } -} diff --git a/routine.go b/routine.go index eba9fde..e9d448d 100644 --- a/routine.go +++ b/routine.go @@ -206,7 +206,7 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { @@ -245,7 +245,7 @@ func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error()) @@ -305,7 +305,7 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil {