This commit is contained in:
Dmitry Pankov 2025-02-20 21:57:39 -05:00 committed by GitHub
commit 988155602b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 20 deletions

View file

@ -152,6 +152,10 @@ BindAddress = 127.0.0.1:25345
#Username = ... #Username = ...
# Avoid using spaces in the password field # Avoid using spaces in the password field
#Password = ... #Password = ...
# Specifying certificate and key enables HTTPS
#CertFile = ...
#KeyFile = ...
``` ```
Alternatively, if you already have a wireguard config, you can import it in the Alternatively, if you already have a wireguard config, you can import it in the

View file

@ -57,6 +57,8 @@ type HTTPConfig struct {
BindAddress string BindAddress string
Username string Username string
Password string Password string
CertFile string
KeyFile string
} }
type Configuration struct { type Configuration struct {
@ -432,6 +434,12 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
password, _ := parseString(section, "Password") password, _ := parseString(section, "Password")
config.Password = password config.Password = password
certFile, _ := parseString(section, "CertFile")
config.CertFile = certFile
keyFile, _ := parseString(section, "KeyFile")
config.KeyFile = keyFile
return config, nil return config, nil
} }

17
http.go
View file

@ -3,6 +3,7 @@ package wireproxy
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
@ -21,6 +22,7 @@ type HTTPServer struct {
dial func(network, address string) (net.Conn, error) dial func(network, address string) (net.Conn, error)
authRequired bool authRequired bool
tlsRequired bool
} }
func (s *HTTPServer) authenticate(req *http.Request) (int, error) { func (s *HTTPServer) authenticate(req *http.Request) (int, error) {
@ -145,9 +147,22 @@ func (s *HTTPServer) serve(conn net.Conn) {
}() }()
} }
func (s *HTTPServer) listen(network, addr string) (net.Listener, error) {
if s.tlsRequired {
cert, err := tls.LoadX509KeyPair(s.config.CertFile, s.config.KeyFile)
if err != nil {
return nil, err
}
return tls.Listen(network, addr, &tls.Config{Certificates: []tls.Certificate{cert}})
}
return net.Listen(network, addr)
}
// ListenAndServe is used to create a listener and serve on it // ListenAndServe is used to create a listener and serve on it
func (s *HTTPServer) ListenAndServe(network, addr string) error { func (s *HTTPServer) ListenAndServe(network, addr string) error {
server, err := net.Listen(network, addr) server, err := s.listen(network, addr)
if err != nil { if err != nil {
return fmt.Errorf("listen tcp failed: %w", err) return fmt.Errorf("listen tcp failed: %w", err)
} }

16
net.go
View file

@ -1,16 +0,0 @@
// will delete when upgrading to go 1.18
package wireproxy
import (
"net"
"net/netip"
)
func TCPAddrFromAddrPort(addr netip.AddrPort) *net.TCPAddr {
return &net.TCPAddr{
IP: addr.Addr().AsSlice(),
Zone: addr.Addr().Zone(),
Port: int(addr.Port()),
}
}

View file

@ -174,6 +174,10 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) {
server.authRequired = true server.authRequired = true
} }
if config.CertFile != "" && config.KeyFile != "" {
server.tlsRequired = true
}
if err := server.ListenAndServe("tcp", config.BindAddress); err != nil { if err := server.ListenAndServe("tcp", config.BindAddress); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -206,7 +210,7 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
return return
} }
tcpAddr := TCPAddrFromAddrPort(*target) tcpAddr := net.TCPAddrFromAddrPort(*target)
sconn, err := vt.Tnet.DialTCP(tcpAddr) sconn, err := vt.Tnet.DialTCP(tcpAddr)
if err != nil { if err != nil {
@ -233,7 +237,7 @@ func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) {
return return
} }
tcpAddr := TCPAddrFromAddrPort(*target) tcpAddr := net.TCPAddrFromAddrPort(*target)
sconn, err := vt.Tnet.DialTCP(tcpAddr) sconn, err := vt.Tnet.DialTCP(tcpAddr)
if err != nil { if err != nil {
errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error()) errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error())
@ -283,7 +287,7 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
return return
} }
tcpAddr := TCPAddrFromAddrPort(*target) tcpAddr := net.TCPAddrFromAddrPort(*target)
sconn, err := net.DialTCP("tcp", nil, tcpAddr) sconn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil { if err != nil {