diff --git a/README.md b/README.md index 5c774e0..c3945b3 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,10 @@ BindAddress = 127.0.0.1:25345 #Username = ... # Avoid using spaces in the password field #Password = ... + +# Specifying certificate and key enables HTTPS +#CertFile = ... +#KeyFile = ... ``` Alternatively, if you already have a wireguard config, you can import it in the diff --git a/config.go b/config.go index 1f6e4e4..cce2520 100644 --- a/config.go +++ b/config.go @@ -57,6 +57,8 @@ type HTTPConfig struct { BindAddress string Username string Password string + CertFile string + KeyFile string } type Configuration struct { @@ -432,6 +434,12 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) { password, _ := parseString(section, "Password") config.Password = password + certFile, _ := parseString(section, "CertFile") + config.CertFile = certFile + + keyFile, _ := parseString(section, "KeyFile") + config.KeyFile = keyFile + return config, nil } diff --git a/http.go b/http.go index 88a7ef4..a010d28 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package wireproxy import ( "bufio" "bytes" + "crypto/tls" "encoding/base64" "fmt" "io" @@ -21,6 +22,7 @@ type HTTPServer struct { dial func(network, address string) (net.Conn, error) authRequired bool + tlsRequired bool } func (s *HTTPServer) authenticate(req *http.Request) (int, error) { @@ -145,9 +147,22 @@ func (s *HTTPServer) serve(conn net.Conn) { }() } +func (s *HTTPServer) listen(network, addr string) (net.Listener, error) { + if s.tlsRequired { + cert, err := tls.LoadX509KeyPair(s.config.CertFile, s.config.KeyFile) + if err != nil { + return nil, err + } + + return tls.Listen(network, addr, &tls.Config{Certificates: []tls.Certificate{cert}}) + } + + return net.Listen(network, addr) +} + // ListenAndServe is used to create a listener and serve on it func (s *HTTPServer) ListenAndServe(network, addr string) error { - server, err := net.Listen(network, addr) + server, err := s.listen(network, addr) if err != nil { return fmt.Errorf("listen tcp failed: %w", err) } diff --git a/net.go b/net.go deleted file mode 100644 index 3c89f39..0000000 --- a/net.go +++ /dev/null @@ -1,16 +0,0 @@ -// will delete when upgrading to go 1.18 - -package wireproxy - -import ( - "net" - "net/netip" -) - -func TCPAddrFromAddrPort(addr netip.AddrPort) *net.TCPAddr { - return &net.TCPAddr{ - IP: addr.Addr().AsSlice(), - Zone: addr.Addr().Zone(), - Port: int(addr.Port()), - } -} diff --git a/routine.go b/routine.go index edfc793..8414d8c 100644 --- a/routine.go +++ b/routine.go @@ -174,6 +174,10 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { server.authRequired = true } + if config.CertFile != "" && config.KeyFile != "" { + server.tlsRequired = true + } + if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { log.Fatal(err) } @@ -206,7 +210,7 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { @@ -233,7 +237,7 @@ func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error()) @@ -283,7 +287,7 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { return } - tcpAddr := TCPAddrFromAddrPort(*target) + tcpAddr := net.TCPAddrFromAddrPort(*target) sconn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil {