mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-21 16:49:03 +01:00
panel apperance, added experimental tcp/udp proxy support, slight performance improvement for http proxy
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
package go_proxy
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -17,22 +16,21 @@ import (
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
type ProxyConfig struct {
|
||||
Alias string
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
Path string
|
||||
Path string // http proxy only
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
Url *url.URL
|
||||
Path string
|
||||
func NewProxyConfig() ProxyConfig {
|
||||
return ProxyConfig{}
|
||||
}
|
||||
|
||||
var dockerClient *client.Client
|
||||
var subdomainRouteMap map[string]mapset.Set[Route] // subdomain -> path
|
||||
|
||||
func buildContainerCfg(container types.Container) {
|
||||
func buildContainerRoute(container types.Container) {
|
||||
var aliases []string
|
||||
|
||||
container_name := strings.TrimPrefix(container.Names[0], "/")
|
||||
@@ -44,7 +42,7 @@ func buildContainerCfg(container types.Container) {
|
||||
}
|
||||
|
||||
for _, alias := range aliases {
|
||||
config := NewConfig()
|
||||
config := NewProxyConfig()
|
||||
prefix := fmt.Sprintf("proxy.%s.", alias)
|
||||
for label, value := range container.Labels {
|
||||
if strings.HasPrefix(label, prefix) {
|
||||
@@ -76,11 +74,22 @@ func buildContainerCfg(container types.Container) {
|
||||
if config.Scheme == "" {
|
||||
if strings.HasSuffix(config.Port, "443") {
|
||||
config.Scheme = "https"
|
||||
} else {
|
||||
} else if strings.HasPrefix(container.Image, "sha256:") {
|
||||
config.Scheme = "http"
|
||||
} else {
|
||||
imageSplit := strings.Split(container.Image, "/")
|
||||
imageSplit = strings.Split(imageSplit[len(imageSplit)-1], ":")
|
||||
imageName := imageSplit[0]
|
||||
_, isKnownImage := imageNamePortMap[imageName]
|
||||
if isKnownImage {
|
||||
log.Printf("[Build] Known image '%s' detected for %s", imageName, container_name)
|
||||
config.Scheme = "tcp"
|
||||
} else {
|
||||
config.Scheme = "http"
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.Scheme != "http" && config.Scheme != "https" {
|
||||
if !isValidScheme(config.Scheme) {
|
||||
log.Printf("%s: unsupported scheme: %s, using http", container_name, config.Scheme)
|
||||
config.Scheme = "http"
|
||||
}
|
||||
@@ -91,26 +100,39 @@ func buildContainerCfg(container types.Container) {
|
||||
config.Host = "host.docker.internal"
|
||||
}
|
||||
}
|
||||
_, inMap := subdomainRouteMap[alias]
|
||||
if !inMap {
|
||||
subdomainRouteMap[alias] = mapset.NewSet[Route]()
|
||||
}
|
||||
url, err := url.Parse(fmt.Sprintf("%s://%s:%s", config.Scheme, config.Host, config.Port))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
subdomainRouteMap[alias].Add(Route{Url: url, Path: config.Path})
|
||||
config.Alias = alias
|
||||
createProxy(config)
|
||||
}
|
||||
}
|
||||
|
||||
func buildRoutes() {
|
||||
subdomainRouteMap = make(map[string]mapset.Set[Route])
|
||||
initProxyMaps()
|
||||
containerSlice, err := dockerClient.ContainerList(context.Background(), container.ListOptions{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
for _, container := range containerSlice {
|
||||
buildContainerCfg(container)
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "go-proxy"
|
||||
}
|
||||
for _, container := range containerSlice {
|
||||
if container.Names[0] == hostname { // skip self
|
||||
continue
|
||||
}
|
||||
buildContainerRoute(container)
|
||||
}
|
||||
subdomainRouteMap["go-proxy"] = panelRoute
|
||||
}
|
||||
|
||||
func findHTTPRoute(host string, path string) (*HTTPRoute, error) {
|
||||
subdomain := strings.Split(host, ".")[0]
|
||||
routeMap, ok := routes.HTTPRoutes[subdomain]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no matching route for subdomain %s", subdomain)
|
||||
}
|
||||
for _, route := range routeMap {
|
||||
if strings.HasPrefix(path, route.Path) {
|
||||
return &route, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching route for path %s for subdomain %s", path, subdomain)
|
||||
}
|
||||
|
||||
32
src/go-proxy/healthcheck.go
Normal file
32
src/go-proxy/healthcheck.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func healthCheckHttp(targetUrl string) error {
|
||||
// try HEAD first
|
||||
// if HEAD is not allowed, try GET
|
||||
resp, err := healthCheckHttpClient.Head(targetUrl)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
||||
_, err = healthCheckHttpClient.Get(targetUrl)
|
||||
}
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func healthCheckStream(scheme string, host string) error {
|
||||
conn, err := net.DialTimeout(scheme, host, 5*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
return nil
|
||||
}
|
||||
66
src/go-proxy/http_proxy.go
Executable file
66
src/go-proxy/http_proxy.go
Executable file
@@ -0,0 +1,66 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HTTPRoute struct {
|
||||
Url *url.URL
|
||||
Path string
|
||||
Proxy *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
// TODO: default + per proxy
|
||||
var transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 60 * time.Second,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
func NewHTTPRoute(Url *url.URL, Path string) HTTPRoute {
|
||||
proxy := httputil.NewSingleHostReverseProxy(Url)
|
||||
proxy.Transport = transport
|
||||
return HTTPRoute{Url: Url, Path: Path, Proxy: proxy}
|
||||
}
|
||||
|
||||
func redirectToTLS(w http.ResponseWriter, r *http.Request) {
|
||||
// Redirect to the same host but with HTTPS
|
||||
log.Printf("[Redirect] redirecting to https")
|
||||
var redirectCode int
|
||||
if r.Method == http.MethodGet {
|
||||
redirectCode = http.StatusMovedPermanently
|
||||
} else {
|
||||
redirectCode = http.StatusPermanentRedirect
|
||||
}
|
||||
http.Redirect(w, r, fmt.Sprintf("https://%s%s?%s", r.Host, r.URL.Path, r.URL.RawQuery), redirectCode)
|
||||
}
|
||||
|
||||
func httpProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
route, err := findHTTPRoute(r.Host, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Printf("[Request] failed %s %s%s, error: %v",
|
||||
r.Method,
|
||||
r.Host,
|
||||
r.URL.Path,
|
||||
err,
|
||||
)
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
route.Proxy.ServeHTTP(w, r)
|
||||
}
|
||||
81
src/go-proxy/main.go
Executable file → Normal file
81
src/go-proxy/main.go
Executable file → Normal file
@@ -1,46 +1,16 @@
|
||||
package go_proxy
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/client"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
)
|
||||
|
||||
var panelRoute = mapset.NewSet(Route{Url: &url.URL{Scheme: "http", Host: "localhost:81", Path: "/"}, Path: "/"})
|
||||
|
||||
// TODO: default + per proxy
|
||||
var transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 60 * time.Second,
|
||||
KeepAlive: 60 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
func NewConfig() Config {
|
||||
return Config{Scheme: "", Host: "", Port: "", Path: ""}
|
||||
}
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
@@ -49,6 +19,9 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
buildRoutes()
|
||||
log.Printf("[Build] built %v reverse proxies", countProxies())
|
||||
|
||||
go func() {
|
||||
filter := filters.NewArgs(
|
||||
filters.Arg("type", "container"),
|
||||
@@ -62,15 +35,12 @@ func main() {
|
||||
// TODO: handle actor only
|
||||
log.Printf("[Event] %s %s caused rebuild", msg.Action, msg.Actor.Attributes["name"])
|
||||
buildRoutes()
|
||||
log.Printf("[Build] rebuilt %v reverse proxies", len(subdomainRouteMap))
|
||||
log.Printf("[Build] rebuilt %v reverse proxies", countProxies())
|
||||
}
|
||||
}()
|
||||
|
||||
buildRoutes()
|
||||
log.Printf("[Build] built %v reverse proxies", len(subdomainRouteMap))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", handler)
|
||||
mux.HandleFunc("/", httpProxyHandler)
|
||||
|
||||
go func() {
|
||||
log.Println("Starting HTTP server on port 80")
|
||||
@@ -80,8 +50,8 @@ func main() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
log.Println("Starting HTTP panel on port 81")
|
||||
err := http.ListenAndServe(":81", http.HandlerFunc(panelHandler))
|
||||
log.Println("Starting HTTPS panel on port 8443")
|
||||
err := http.ListenAndServeTLS(":8443", "/certs/cert.crt", "/certs/priv.key", http.HandlerFunc(panelHandler))
|
||||
if err != nil {
|
||||
log.Fatal("HTTP server error", err)
|
||||
}
|
||||
@@ -92,38 +62,3 @@ func main() {
|
||||
log.Fatal("HTTPS Server error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func redirectToTLS(w http.ResponseWriter, r *http.Request) {
|
||||
// Redirect to the same host but with HTTPS
|
||||
log.Printf("[Redirect] redirecting to https")
|
||||
var redirectCode int
|
||||
if r.Method == http.MethodGet {
|
||||
redirectCode = http.StatusMovedPermanently
|
||||
} else {
|
||||
redirectCode = http.StatusPermanentRedirect
|
||||
}
|
||||
http.Redirect(w, r, fmt.Sprintf("https://%s%s?%s", r.Host, r.URL.Path, r.URL.RawQuery), redirectCode)
|
||||
}
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("[Request] %s %s", r.Method, r.URL.String())
|
||||
subdomain := strings.Split(r.Host, ".")[0]
|
||||
routeMap, ok := subdomainRouteMap[subdomain]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("no matching route for subdomain %s", subdomain), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
for route := range routeMap.Iter() {
|
||||
if strings.HasPrefix(r.URL.Path, route.Path) {
|
||||
realPath := strings.TrimPrefix(r.URL.Path, route.Path)
|
||||
origHost := r.Host
|
||||
r.URL.Path = realPath
|
||||
log.Printf("[Route] %s -> %s%s ", origHost, route.Url.String(), route.Path)
|
||||
proxyServer := httputil.NewSingleHostReverseProxy(route.Url)
|
||||
proxyServer.Transport = transport
|
||||
proxyServer.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("no matching route for path %s for subdomain %s", r.URL.Path, subdomain), http.StatusNotFound)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package go_proxy
|
||||
package main
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -11,7 +14,13 @@ const templateFile = "/app/templates/panel.html"
|
||||
var healthCheckHttpClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 5 * time.Second,
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -42,7 +51,7 @@ func panelIndex(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = tmpl.Execute(w, subdomainRouteMap)
|
||||
err = tmpl.Execute(w, routes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -61,20 +70,23 @@ func panelCheckTargetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// try HEAD first
|
||||
// if HEAD is not allowed, try GET
|
||||
resp, err := healthCheckHttpClient.Head(targetUrl)
|
||||
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
||||
_, err = healthCheckHttpClient.Get(targetUrl)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
}
|
||||
url, err := url.Parse(targetUrl)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
log.Printf("[Panel] failed to parse %s, error: %v", targetUrl, err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
scheme := url.Scheme
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if isStreamScheme(scheme) {
|
||||
err = healthCheckStream(scheme, url.Host)
|
||||
} else {
|
||||
err = healthCheckHttp(targetUrl)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
120
src/go-proxy/proxy.go
Normal file
120
src/go-proxy/proxy.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Routes struct {
|
||||
HTTPRoutes map[string][]HTTPRoute // subdomain/alias -> path
|
||||
StreamRoutes map[string]*StreamRoute // port -> target
|
||||
}
|
||||
|
||||
var routes = Routes{
|
||||
HTTPRoutes: make(map[string][]HTTPRoute),
|
||||
StreamRoutes: make(map[string]*StreamRoute),
|
||||
}
|
||||
var routesMutex = sync.Mutex{}
|
||||
|
||||
var streamSchemes = []string{"tcp", "udp"} // TODO: support "tcp:udp", "udp:tcp"
|
||||
var httpSchemes = []string{"http", "https"}
|
||||
|
||||
var validSchemes = append(streamSchemes, httpSchemes...)
|
||||
|
||||
var lastFreePort int
|
||||
|
||||
|
||||
func isValidScheme(scheme string) bool {
|
||||
for _, v := range validSchemes {
|
||||
if v == scheme {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isStreamScheme(scheme string) bool {
|
||||
for _, v := range streamSchemes {
|
||||
if v == scheme {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initProxyMaps() {
|
||||
routesMutex.Lock()
|
||||
defer routesMutex.Unlock()
|
||||
|
||||
lastFreePort = 20000
|
||||
oldStreamRoutes := routes.StreamRoutes
|
||||
routes.StreamRoutes = make(map[string]*StreamRoute)
|
||||
routes.HTTPRoutes = make(map[string][]HTTPRoute)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(oldStreamRoutes))
|
||||
defer wg.Wait()
|
||||
|
||||
for _, route := range oldStreamRoutes {
|
||||
go func(r *StreamRoute) {
|
||||
r.Cancel()
|
||||
wg.Done()
|
||||
}(route)
|
||||
}
|
||||
}
|
||||
|
||||
func countProxies() int {
|
||||
return len(routes.HTTPRoutes) + len(routes.StreamRoutes)
|
||||
}
|
||||
|
||||
func createProxy(config ProxyConfig) {
|
||||
if isStreamScheme(config.Scheme) {
|
||||
_, inMap := routes.StreamRoutes[config.Port]
|
||||
if inMap {
|
||||
log.Printf("[Build] Duplicated stream :%s, ignoring", config.Port)
|
||||
return
|
||||
}
|
||||
route, err := NewStreamRoute(config)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
routes.StreamRoutes[config.Port] = route
|
||||
go route.listenStream()
|
||||
} else {
|
||||
_, inMap := routes.HTTPRoutes[config.Alias]
|
||||
if !inMap {
|
||||
routes.HTTPRoutes[config.Alias] = make([]HTTPRoute, 0)
|
||||
}
|
||||
url, err := url.Parse(fmt.Sprintf("%s://%s:%s", config.Scheme, config.Host, config.Port))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
routes.HTTPRoutes[config.Alias] = append(routes.HTTPRoutes[config.Alias], NewHTTPRoute(url, config.Path))
|
||||
}
|
||||
}
|
||||
|
||||
func findFreePort() (int, error) {
|
||||
var portStr string
|
||||
var l net.Listener
|
||||
var err error = nil
|
||||
|
||||
for lastFreePort <= 21000 {
|
||||
portStr = fmt.Sprintf(":%d", lastFreePort)
|
||||
l, err = net.Listen("tcp", portStr)
|
||||
lastFreePort++
|
||||
if err != nil {
|
||||
l.Close()
|
||||
return lastFreePort, nil
|
||||
}
|
||||
}
|
||||
l, err = net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("unable to find free port: %v", err)
|
||||
}
|
||||
// NOTE: may not be after 20000
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
177
src/go-proxy/stream.go
Normal file
177
src/go-proxy/stream.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type StreamRoute struct {
|
||||
Alias string // to show in panel
|
||||
Type string
|
||||
ListeningScheme string
|
||||
ListeningPort string
|
||||
TargetScheme string
|
||||
TargetHost string
|
||||
TargetPort string
|
||||
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
var imageNamePortMap = map[string]string{
|
||||
"postgres": "5432",
|
||||
"mysql": "3306",
|
||||
"mariadb": "3306",
|
||||
"redis": "6379",
|
||||
"mssql": "1433",
|
||||
"memcached": "11211",
|
||||
"rabbitmq": "5672",
|
||||
}
|
||||
var extraNamePortMap = map[string]string{
|
||||
"dns": "53",
|
||||
"ssh": "22",
|
||||
"ftp": "21",
|
||||
"smtp": "25",
|
||||
"pop3": "110",
|
||||
"imap": "143",
|
||||
}
|
||||
var namePortMap = func() map[string]string {
|
||||
m := make(map[string]string)
|
||||
for k, v := range imageNamePortMap {
|
||||
m[k] = v
|
||||
}
|
||||
for k, v := range extraNamePortMap {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}()
|
||||
|
||||
const UDPStreamType = "udp"
|
||||
const TCPStreamType = "tcp"
|
||||
|
||||
func NewStreamRoute(config ProxyConfig) (*StreamRoute, error) {
|
||||
port_split := strings.Split(config.Port, ":")
|
||||
|
||||
var streamType string = TCPStreamType
|
||||
var srcPort string
|
||||
var dstPort string
|
||||
var srcScheme string
|
||||
var dstScheme string
|
||||
var srcUDPAddr *net.UDPAddr = nil
|
||||
var dstUDPAddr *net.UDPAddr = nil
|
||||
|
||||
if len(port_split) != 2 {
|
||||
warnMsg := fmt.Sprintf(`[Build] Invalid stream port %s, `+
|
||||
`should be <listeningPort>:<targetPort>`, config.Port)
|
||||
freePort, err := findFreePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s and %s", warnMsg, err)
|
||||
}
|
||||
srcPort = fmt.Sprintf("%d", freePort)
|
||||
dstPort = config.Port
|
||||
fmt.Printf(`%s, assuming %s is targetPort and `+
|
||||
`using free port %s as listeningPort`,
|
||||
warnMsg,
|
||||
srcPort,
|
||||
dstPort,
|
||||
)
|
||||
} else {
|
||||
srcPort = port_split[0]
|
||||
dstPort = port_split[1]
|
||||
}
|
||||
|
||||
port, hasName := namePortMap[dstPort]
|
||||
if hasName {
|
||||
dstPort = port
|
||||
}
|
||||
_, err := strconv.Atoi(dstPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"[Build] Unrecognized stream target port %s, ignoring",
|
||||
dstPort,
|
||||
)
|
||||
}
|
||||
|
||||
scheme_split := strings.Split(config.Scheme, ":")
|
||||
|
||||
if len(scheme_split) == 2 {
|
||||
srcScheme = scheme_split[0]
|
||||
dstScheme = scheme_split[1]
|
||||
} else {
|
||||
srcScheme = config.Scheme
|
||||
dstScheme = config.Scheme
|
||||
}
|
||||
|
||||
if srcScheme == "udp" {
|
||||
streamType = UDPStreamType
|
||||
srcUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%s", srcPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if dstScheme == "udp" {
|
||||
streamType = UDPStreamType
|
||||
dstUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%s", config.Host, dstPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
route := StreamRoute{
|
||||
Alias: config.Alias,
|
||||
Type: streamType,
|
||||
ListeningScheme: srcScheme,
|
||||
TargetScheme: dstScheme,
|
||||
TargetHost: config.Host,
|
||||
ListeningPort: srcPort,
|
||||
TargetPort: dstPort,
|
||||
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
|
||||
if streamType == UDPStreamType {
|
||||
return (*StreamRoute)(unsafe.Pointer(&UDPRoute{
|
||||
StreamRoute: route,
|
||||
ConnMap: make(map[net.Addr]*net.UDPConn),
|
||||
ConnMapMutex: sync.Mutex{},
|
||||
QueueSize: atomic.Int32{},
|
||||
SourceUDPAddr: srcUDPAddr,
|
||||
TargetUDPAddr: dstUDPAddr,
|
||||
})), nil
|
||||
}
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func (route *StreamRoute) PrintError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Printf("[Stream] %s => %s error: %v", route.ListeningUrl(), route.TargetUrl(), err)
|
||||
}
|
||||
|
||||
func (route *StreamRoute) ListeningUrl() string {
|
||||
return fmt.Sprintf("%s://:%s", route.ListeningScheme, route.ListeningPort)
|
||||
}
|
||||
|
||||
func (route *StreamRoute) TargetUrl() string {
|
||||
return fmt.Sprintf("%s://%s:%s", route.TargetScheme, route.TargetHost, route.TargetPort)
|
||||
}
|
||||
|
||||
func (route *StreamRoute) listenStream() {
|
||||
if route.Type == UDPStreamType {
|
||||
listenUDP((*UDPRoute)(unsafe.Pointer(route)))
|
||||
} else {
|
||||
listenTCP(route)
|
||||
}
|
||||
}
|
||||
74
src/go-proxy/tcp.go
Normal file
74
src/go-proxy/tcp.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const tcpDialTimeout = 5 * time.Second
|
||||
|
||||
func listenTCP(route *StreamRoute) {
|
||||
in, err := net.Listen(
|
||||
route.ListeningScheme,
|
||||
fmt.Sprintf(":%s", route.ListeningPort),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[Stream Listen] %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer in.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-route.Context.Done():
|
||||
return
|
||||
default:
|
||||
clientConn, err := in.Accept()
|
||||
if err != nil {
|
||||
log.Printf("[Stream Accept] %v", err)
|
||||
return
|
||||
}
|
||||
go connectTCPPipe(route, clientConn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectTCPPipe(route *StreamRoute, clientConn net.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
serverAddr := fmt.Sprintf("%s:%s", route.TargetHost, route.TargetPort)
|
||||
dialer := &net.Dialer{}
|
||||
serverConn, err := dialer.DialContext(ctx, route.TargetScheme, serverAddr)
|
||||
if err != nil {
|
||||
log.Printf("[Stream Dial] %v", err)
|
||||
return
|
||||
}
|
||||
tcpPipe(route, clientConn, serverConn)
|
||||
}
|
||||
|
||||
func tcpPipe(route *StreamRoute, src net.Conn, dest net.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2) // Number of goroutines
|
||||
defer src.Close()
|
||||
defer dest.Close()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(src, dest)
|
||||
go route.PrintError(err)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(dest, src)
|
||||
go route.PrintError(err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
125
src/go-proxy/udp.go
Normal file
125
src/go-proxy/udp.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const udpBufferSize = 1500
|
||||
const udpMaxQueueSizePerStream = 100
|
||||
const udpListenTimeout = 100 * time.Second
|
||||
const udpConnectionTimeout = 30 * time.Second
|
||||
|
||||
type UDPRoute struct {
|
||||
StreamRoute
|
||||
|
||||
ConnMap map[net.Addr]*net.UDPConn
|
||||
ConnMapMutex sync.Mutex
|
||||
QueueSize atomic.Int32
|
||||
SourceUDPAddr *net.UDPAddr
|
||||
TargetUDPAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func listenUDP(route *UDPRoute) {
|
||||
source, err := net.ListenUDP(route.ListeningScheme, route.SourceUDPAddr)
|
||||
if err != nil {
|
||||
route.PrintError(err)
|
||||
return
|
||||
}
|
||||
|
||||
target, err := net.DialUDP(route.TargetScheme, nil, route.TargetUDPAddr)
|
||||
if err != nil {
|
||||
route.PrintError(err)
|
||||
return
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
defer wg.Wait()
|
||||
defer source.Close()
|
||||
defer target.Close()
|
||||
|
||||
var udpBuffers = [udpMaxQueueSizePerStream][udpBufferSize]byte{}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-route.Context.Done():
|
||||
return
|
||||
default:
|
||||
if route.QueueSize.Load() >= udpMaxQueueSizePerStream {
|
||||
wg.Wait()
|
||||
}
|
||||
go udpLoop(
|
||||
route,
|
||||
source,
|
||||
target,
|
||||
udpBuffers[route.QueueSize.Load()][:],
|
||||
&wg,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func udpLoop(route *UDPRoute, in *net.UDPConn, out *net.UDPConn, buffer []byte, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
route.QueueSize.Add(1)
|
||||
defer route.QueueSize.Add(-1)
|
||||
defer wg.Done()
|
||||
|
||||
in.SetReadDeadline(time.Now().Add(udpListenTimeout))
|
||||
|
||||
var nRead int
|
||||
var nWritten int
|
||||
nRead, srcAddr, err := in.ReadFromUDP(buffer)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, srcAddr.String(), out.RemoteAddr().String())
|
||||
out.SetWriteDeadline(time.Now().Add(udpConnectionTimeout))
|
||||
nWritten, err = out.Write(buffer[:nRead])
|
||||
if nWritten != nRead {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
if err != nil {
|
||||
go route.PrintError(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = udpPipe(route, out, srcAddr, buffer)
|
||||
if err != nil {
|
||||
go route.PrintError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func udpPipe(route *UDPRoute, src *net.UDPConn, destAddr *net.UDPAddr, buffer []byte) error {
|
||||
src.SetReadDeadline(time.Now().Add(udpConnectionTimeout))
|
||||
nRead, err := src.Read(buffer)
|
||||
if err != nil || nRead == 0 {
|
||||
return err
|
||||
}
|
||||
log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, src.RemoteAddr().String(), destAddr.String())
|
||||
dest, ok := route.ConnMap[destAddr]
|
||||
if !ok {
|
||||
dest, err = net.DialUDP(src.LocalAddr().Network(), nil, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.ConnMapMutex.Lock()
|
||||
route.ConnMap[destAddr] = dest
|
||||
route.ConnMapMutex.Unlock()
|
||||
}
|
||||
dest.SetWriteDeadline(time.Now().Add(udpConnectionTimeout))
|
||||
nWritten, err := dest.Write(buffer[:nRead])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nWritten != nRead {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user