feat(socket-proxy): implement Docker socket proxy and related configurations

- Updated Dockerfile and Makefile for socket-proxy build.
- Modified go.mod to include necessary dependencies.
- Updated CI workflows for socket-proxy integration.
- Better module isolation
- Code refactor
This commit is contained in:
yusing
2025-05-10 09:47:03 +08:00
parent 4ddfb48b9d
commit 8fe94d6d14
38 changed files with 658 additions and 523 deletions

View File

@@ -12,6 +12,7 @@ import (
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/pkg"
@@ -45,7 +46,7 @@ func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...b
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
gpwebsocket.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}

View File

@@ -3,8 +3,7 @@ package common
import (
"crypto/rand"
"encoding/base64"
"github.com/rs/zerolog/log"
"log"
)
func decodeJWTKey(key string) []byte {
@@ -13,7 +12,7 @@ func decodeJWTKey(key string) []byte {
}
bytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
log.Fatal().Str("key", key).Err(err).Msg("failed to decode secret")
log.Fatalf("failed to decode secret: %s", err)
}
return bytes
}
@@ -22,7 +21,7 @@ func RandomJWTKey() []byte {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
log.Fatal().Err(err).Msg("failed to generate random jwt key")
log.Fatalf("failed to generate random jwt key: %s", err)
}
return key
}

View File

@@ -2,13 +2,13 @@ package common
import (
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -78,7 +78,7 @@ func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T
if err == nil {
return parsed
}
log.Fatal().Err(err).Msgf("env %s: invalid %T value: %s", key, parsed, value)
log.Fatalf("env %s: invalid %T value: %s", key, parsed, value)
return defaultValue
}
@@ -105,7 +105,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host string, portInt in
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Fatal().Msgf("env %s: invalid address: %s", key, addr)
log.Fatalf("env %s: invalid address: %s", key, addr)
}
if host == "" {
host = "localhost"
@@ -113,7 +113,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host string, portInt in
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
portInt, err = strconv.Atoi(port)
if err != nil {
log.Fatal().Msgf("env %s: invalid port: %s", key, port)
log.Fatalf("env %s: invalid port: %s", key, port)
}
return
}

View File

@@ -40,7 +40,7 @@ func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PE
var agentCfg agent.AgentConfig
agentCfg.Addr = host
err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key)
err := agentCfg.StartWithCerts(cfg.Task().Context(), ca.Cert, client.Cert, client.Key)
if err != nil {
return 0, gperr.Wrap(err, "failed to start agent")
}

View File

@@ -328,8 +328,8 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
removeAllAgents()
for _, agent := range providers.Agents {
if err := agent.Start(cfg.task); err != nil {
errs.Add(err.Subject(agent.String()))
if err := agent.Start(cfg.task.Context()); err != nil {
errs.Add(gperr.PrependSubject(agent.String(), err))
continue
}
addAgent(agent)

View File

@@ -6,7 +6,7 @@ replace github.com/yusing/go-proxy => ../..
require (
github.com/go-acme/lego/v4 v4.23.1
github.com/yusing/go-proxy v0.12.3
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
)
require (
@@ -146,13 +146,13 @@ require (
github.com/spf13/viper v1.20.1 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
github.com/transip/gotransip/v6 v6.26.0 // indirect
github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec // indirect
github.com/vinyldns/go-vinyldns v0.9.16 // indirect
github.com/volcengine/volc-sdk-golang v1.0.206 // indirect
github.com/volcengine/volc-sdk-golang v1.0.207 // indirect
github.com/vultr/govultr/v3 v3.20.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect

View File

@@ -1519,8 +1519,8 @@ github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNG
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1136/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 h1:jKVzMJy52E0zGbabQiZ7KaaYJwwwWblZAKgkt0Mex5E=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 h1:S4dJSWhOtaPjp0/GO/yhzUC6DfZvpWhrnsEKaLxr73c=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 h1:kMIdSU5IvpOROh27ToVQ3hlm6ym3lCRs9tnGCOBoZqk=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136/go.mod h1:FpyIz3mymKaExVs6Fz27kxDBS42jqZn7vbACtxdeEH4=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
@@ -1538,8 +1538,8 @@ github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec/go.mod h1:BZr7
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ=
github.com/vinyldns/go-vinyldns v0.9.16/go.mod h1:5qIJOdmzAnatKjurI+Tl4uTus7GJKJxb+zitufjHs3Q=
github.com/volcengine/volc-sdk-golang v1.0.206 h1:7NG8FCpvu9wbx+Z4I/p3tcTS2zdBqTZtJXgydunGy6g=
github.com/volcengine/volc-sdk-golang v1.0.206/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/volcengine/volc-sdk-golang v1.0.207 h1:1OJ/nC92dF1URRoyO1AHSghCob12NT1PAA/GoK8uU18=
github.com/volcengine/volc-sdk-golang v1.0.207/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/vultr/govultr/v3 v3.20.0 h1:O+Om6gXpN6ehwAIIKq5DyGuekpyHaoRlwrxTb44bDzA=
github.com/vultr/govultr/v3 v3.20.0/go.mod h1:q34Wd76upKmf+vxFMgaNMH3A8BbsPBmSYZUGC8oZa5w=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=

View File

@@ -3,8 +3,8 @@ package accesslog_test
import (
"testing"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/utils"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)

View File

@@ -9,9 +9,8 @@ import (
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type logEntryRange struct {
@@ -22,8 +21,8 @@ type memLogger struct {
*bytes.Buffer
sync.RWMutex
notifyLock sync.RWMutex
connChans F.Map[chan *logEntryRange, struct{}]
listeners F.Map[chan []byte, struct{}]
connChans *xsync.MapOf[chan *logEntryRange, struct{}]
listeners *xsync.MapOf[chan []byte, struct{}]
}
type MemLogger io.Writer
@@ -40,8 +39,8 @@ const (
var memLoggerInstance = &memLogger{
Buffer: bytes.NewBuffer(make([]byte, maxMemLogSize)),
connChans: F.NewMapOf[chan *logEntryRange, struct{}](),
listeners: F.NewMapOf[chan []byte, struct{}](),
connChans: xsync.NewMapOf[chan *logEntryRange, struct{}](),
listeners: xsync.NewMapOf[chan []byte, struct{}](),
}
func GetMemLogger() MemLogger {
@@ -136,7 +135,7 @@ func (m *memLogger) Write(p []byte) (n int, err error) {
func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -153,7 +152,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()
if err := m.wsInitial(r.Context(), conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

View File

@@ -6,11 +6,7 @@ import (
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
func warnNoMatchDomains() {
@@ -19,13 +15,25 @@ func warnNoMatchDomains() {
var warnNoMatchDomainOnce sync.Once
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug {
allowedDomains := WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
} else {
@@ -47,14 +55,14 @@ func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -67,7 +75,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
return
case <-ticker.C:
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
@@ -79,7 +87,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
// It logs an error if the message is not written successfully.
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
gperr.LogError("failed to write text message", err)
logging.Err(err).Msg("failed to write text message")
return false
}
return true

View File

@@ -4,18 +4,6 @@ import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View File

@@ -4,12 +4,21 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
. "github.com/yusing/go-proxy/internal/utils/strutils"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func mustParseTime(t *testing.T, layout, value string) time.Time {
t.Helper()
time, err := time.Parse(layout, value)
if err != nil {
t.Fatalf("failed to parse time: %s", err)
}
return time
}
func TestFormatTime(t *testing.T) {
now := expect.Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z"))
now := mustParseTime(t, time.RFC3339, "2021-06-15T12:30:30Z")
tests := []struct {
name string
@@ -84,9 +93,9 @@ func TestFormatTime(t *testing.T) {
result := FormatTimeWithReference(tt.time, now)
if tt.expectedLength > 0 {
expect.Equal(t, len(result), tt.expectedLength, result)
require.Equal(t, tt.expectedLength, len(result), result)
} else {
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
}
})
}
@@ -173,7 +182,7 @@ func TestFormatDuration(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatDuration(tt.duration)
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
})
}
}
@@ -203,7 +212,7 @@ func TestFormatLastSeen(t *testing.T) {
result := FormatLastSeen(tt.time)
if tt.name == "zero time" {
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
} else {
// Just make sure it's not "never", the actual formatting is tested in TestFormatTime
if result == "never" {
@@ -290,7 +299,7 @@ func TestFormatByteSize(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatByteSize(tt.size)
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -4,8 +4,8 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/require"
. "github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var alphaNumeric = func() string {
@@ -31,8 +31,8 @@ func TestSplit(t *testing.T) {
for sep, rsep := range tests {
t.Run(sep, func(t *testing.T) {
expected := strings.Split(alphaNumeric, sep)
ExpectEqual(t, SplitRune(alphaNumeric, rsep), expected)
ExpectEqual(t, JoinRune(expected, rsep), alphaNumeric)
require.Equal(t, expected, SplitRune(alphaNumeric, rsep))
require.Equal(t, alphaNumeric, JoinRune(expected, rsep))
})
}
}