feat(socket-proxy): implement Docker socket proxy and related configurations

- Updated Dockerfile and Makefile for socket-proxy build.
- Modified go.mod to include necessary dependencies.
- Updated CI workflows for socket-proxy integration.
- Better module isolation
- Code refactor
This commit is contained in:
yusing
2025-05-10 09:47:03 +08:00
parent 4ddfb48b9d
commit 8fe94d6d14
38 changed files with 658 additions and 523 deletions

View File

@@ -5,6 +5,8 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@@ -14,10 +16,7 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg"
)
@@ -80,7 +79,7 @@ func (cfg *AgentConfig) Parse(addr string) error {
return nil
}
func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) error {
func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) error {
clientCert, err := tls.X509KeyPair(crt, key)
if err != nil {
return err
@@ -90,7 +89,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(ca)
if !ok {
return gperr.New("invalid ca certificate")
return errors.New("invalid ca certificate")
}
cfg.tlsConfig = &tls.Config{
@@ -102,7 +101,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
// create transport and http client
cfg.httpClient = cfg.NewHTTPClient()
ctx, cancel := context.WithTimeout(parent.Context(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// get agent name
@@ -131,23 +130,23 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
return nil
}
func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error {
func (cfg *AgentConfig) Start(ctx context.Context) error {
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
if !ok {
return gperr.New("invalid agent host").Subject(cfg.Addr)
return fmt.Errorf("invalid agent host: %s", cfg.Addr)
}
certData, err := os.ReadFile(filepath)
if err != nil {
return gperr.Wrap(err, "failed to read agent certs")
return fmt.Errorf("failed to read agent certs: %w", err)
}
ca, crt, key, err := certs.ExtractCert(certData)
if err != nil {
return gperr.Wrap(err, "failed to extract agent certs")
return fmt.Errorf("failed to extract agent certs: %w", err)
}
return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key))
return cfg.StartWithCerts(ctx, ca, crt, key)
}
func (cfg *AgentConfig) NewHTTPClient() *http.Client {
@@ -171,8 +170,10 @@ func (cfg *AgentConfig) Transport() *http.Transport {
}
}
var dialer = &net.Dialer{Timeout: 5 * time.Second}
func (cfg *AgentConfig) DialContext(ctx context.Context) (net.Conn, error) {
return gphttp.DefaultDialer.DialContext(ctx, "tcp", cfg.Addr)
return dialer.DialContext(ctx, "tcp", cfg.Addr)
}
func (cfg *AgentConfig) Name() string {

View File

@@ -8,59 +8,59 @@ import (
"net/http/httptest"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/stretchr/testify/require"
)
func TestNewAgent(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
ExpectTrue(t, ca != nil)
ExpectTrue(t, srv != nil)
ExpectTrue(t, client != nil)
require.NoError(t, err)
require.NotNil(t, ca)
require.NotNil(t, srv)
require.NotNil(t, client)
}
func TestPEMPair(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) {
var pp PEMPair
err := pp.Load(p.String())
ExpectNoError(t, err)
ExpectEqual(t, p.Cert, pp.Cert)
ExpectEqual(t, p.Key, pp.Key)
require.NoError(t, err)
require.Equal(t, p.Cert, pp.Cert)
require.Equal(t, p.Key, pp.Key)
})
}
}
func TestPEMPairToTLSCert(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) {
cert, err := p.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, cert != nil)
require.NoError(t, err)
require.NotNil(t, cert)
})
}
}
func TestServerClient(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
srvTLS, err := srv.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, srvTLS != nil)
require.NoError(t, err)
require.NotNil(t, srvTLS)
clientTLS, err := client.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, clientTLS != nil)
require.NoError(t, err)
require.NotNil(t, clientTLS)
caPool := x509.NewCertPool()
ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert))
require.True(t, caPool.AppendCertsFromPEM(ca.Cert))
srvTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*srvTLS},
@@ -86,6 +86,6 @@ func TestServerClient(t *testing.T) {
}
resp, err := httpClient.Get(server.URL)
ExpectNoError(t, err)
ExpectEqual(t, resp.StatusCode, http.StatusOK)
require.NoError(t, err)
require.Equal(t, resp.StatusCode, http.StatusOK)
}

60
agent/pkg/env/env.go vendored
View File

@@ -20,35 +20,6 @@ var (
AgentSkipClientCertCheck bool
AgentCACert string
AgentSSLCert string
DockerSocketAddr string
DockerPost bool
DockerRestarts bool
DockerStart bool
DockerStop bool
DockerAuth bool
DockerBuild bool
DockerCommit bool
DockerConfigs bool
DockerContainers bool
DockerDistribution bool
DockerEvents bool
DockerExec bool
DockerGrpc bool
DockerImages bool
DockerInfo bool
DockerNetworks bool
DockerNodes bool
DockerPing bool
DockerPlugins bool
DockerSecrets bool
DockerServices bool
DockerSession bool
DockerSwarm bool
DockerSystem bool
DockerTasks bool
DockerVersion bool
DockerVolumes bool
)
func init() {
@@ -62,35 +33,4 @@ func Load() {
AgentCACert = common.GetEnvString("AGENT_CA_CERT", "")
AgentSSLCert = common.GetEnvString("AGENT_SSL_CERT", "")
// docker socket proxy
DockerSocketAddr = common.GetEnvString("DOCKER_SOCKET_ADDR", "127.0.0.1:2375")
DockerPost = common.GetEnvBool("POST", false)
DockerRestarts = common.GetEnvBool("ALLOW_RESTARTS", false)
DockerStart = common.GetEnvBool("ALLOW_START", false)
DockerStop = common.GetEnvBool("ALLOW_STOP", false)
DockerAuth = common.GetEnvBool("AUTH", false)
DockerBuild = common.GetEnvBool("BUILD", false)
DockerCommit = common.GetEnvBool("COMMIT", false)
DockerConfigs = common.GetEnvBool("CONFIGS", false)
DockerContainers = common.GetEnvBool("CONTAINERS", false)
DockerDistribution = common.GetEnvBool("DISTRIBUTION", false)
DockerEvents = common.GetEnvBool("EVENTS", true)
DockerExec = common.GetEnvBool("EXEC", false)
DockerGrpc = common.GetEnvBool("GRPC", false)
DockerImages = common.GetEnvBool("IMAGES", false)
DockerInfo = common.GetEnvBool("INFO", false)
DockerNetworks = common.GetEnvBool("NETWORKS", false)
DockerNodes = common.GetEnvBool("NODES", false)
DockerPing = common.GetEnvBool("PING", true)
DockerPlugins = common.GetEnvBool("PLUGINS", false)
DockerSecrets = common.GetEnvBool("SECRETS", false)
DockerServices = common.GetEnvBool("SERVICES", false)
DockerSession = common.GetEnvBool("SESSION", false)
DockerSwarm = common.GetEnvBool("SWARM", false)
DockerSystem = common.GetEnvBool("SYSTEM", false)
DockerTasks = common.GetEnvBool("TASKS", false)
DockerVersion = common.GetEnvBool("VERSION", true)
DockerVolumes = common.GetEnvBool("VOLUMES", false)
}

View File

@@ -1,13 +1,13 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
@@ -73,5 +73,7 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
return
}
gphttp.RespondJSON(w, r, result)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(result)
}

View File

@@ -1,414 +0,0 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/yusing/go-proxy/agent/pkg/env"
)
func TestNewDockerHandler(t *testing.T) {
tests := []struct {
name string
method string
path string
envSetup func()
wantStatusCode int
}{
{
name: "GET _ping allowed by default",
method: http.MethodGet,
path: "/_ping",
envSetup: func() {},
wantStatusCode: http.StatusOK,
},
{
name: "GET version allowed by default",
method: http.MethodGet,
path: "/version",
envSetup: func() {},
wantStatusCode: http.StatusOK,
},
{
name: "GET containers allowed when enabled",
method: http.MethodGet,
path: "/containers",
envSetup: func() {
env.DockerContainers = true
},
wantStatusCode: http.StatusOK,
},
{
name: "GET containers not allowed when disabled",
method: http.MethodGet,
path: "/containers",
envSetup: func() {
env.DockerContainers = false
},
wantStatusCode: http.StatusForbidden,
},
{
name: "POST not allowed by default",
method: http.MethodPost,
path: "/_ping",
envSetup: func() {
env.DockerPost = false
},
wantStatusCode: http.StatusMethodNotAllowed,
},
{
name: "POST allowed when enabled",
method: http.MethodPost,
path: "/_ping",
envSetup: func() {
env.DockerPost = true
env.DockerPing = true
},
wantStatusCode: http.StatusOK,
},
{
name: "Container restart not allowed when disabled",
method: http.MethodPost,
path: "/containers/test-container/restart",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerRestarts = false
},
wantStatusCode: http.StatusForbidden,
},
{
name: "Container restart allowed when enabled",
method: http.MethodPost,
path: "/containers/test-container/restart",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerRestarts = true
},
wantStatusCode: http.StatusOK,
},
{
name: "Container start not allowed when disabled",
method: http.MethodPost,
path: "/containers/test-container/start",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStart = false
},
wantStatusCode: http.StatusForbidden,
},
{
name: "Container start allowed when enabled",
method: http.MethodPost,
path: "/containers/test-container/start",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStart = true
},
wantStatusCode: http.StatusOK,
},
{
name: "Container stop not allowed when disabled",
method: http.MethodPost,
path: "/containers/test-container/stop",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStop = false
},
wantStatusCode: http.StatusForbidden,
},
{
name: "Container stop allowed when enabled",
method: http.MethodPost,
path: "/containers/test-container/stop",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStop = true
},
wantStatusCode: http.StatusOK,
},
{
name: "Versioned API paths work",
method: http.MethodGet,
path: "/v1.41/version",
envSetup: func() {
env.DockerVersion = true
},
wantStatusCode: http.StatusOK,
},
{
name: "PUT method not allowed",
method: http.MethodPut,
path: "/version",
envSetup: func() {
env.DockerVersion = true
},
wantStatusCode: http.StatusMethodNotAllowed,
},
{
name: "DELETE method not allowed",
method: http.MethodDelete,
path: "/version",
envSetup: func() {
env.DockerVersion = true
},
wantStatusCode: http.StatusMethodNotAllowed,
},
}
// Save original env values to restore after tests
originalContainers := env.DockerContainers
originalRestarts := env.DockerRestarts
originalStart := env.DockerStart
originalStop := env.DockerStop
originalPost := env.DockerPost
originalPing := env.DockerPing
originalVersion := env.DockerVersion
defer func() {
// Restore original values
env.DockerContainers = originalContainers
env.DockerRestarts = originalRestarts
env.DockerStart = originalStart
env.DockerStop = originalStop
env.DockerPost = originalPost
env.DockerPing = originalPing
env.DockerVersion = originalVersion
}()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup environment for this test
tt.envSetup()
// Create test handler that will record the response for verification
dockerHandler := NewDockerHandler()
// Test server to capture the response
recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest(tt.method, tt.path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Process the request
dockerHandler.ServeHTTP(recorder, req)
// Check response
if recorder.Code != tt.wantStatusCode {
t.Errorf("Expected status code %d, got %d",
tt.wantStatusCode, recorder.Code)
}
})
}
}
// This test focuses on checking that all the path prefix handling works correctly
func TestNewDockerHandler_PathHandling(t *testing.T) {
tests := []struct {
name string
path string
envVarName string
envVarValue bool
method string
wantAllowed bool
}{
{"Container path", "/containers/json", "DockerContainers", true, http.MethodGet, true},
{"Container path disabled", "/containers/json", "DockerContainers", false, http.MethodGet, false},
{"Auth path", "/auth", "DockerAuth", true, http.MethodGet, true},
{"Auth path disabled", "/auth", "DockerAuth", false, http.MethodGet, false},
{"Build path", "/build", "DockerBuild", true, http.MethodGet, true},
{"Build path disabled", "/build", "DockerBuild", false, http.MethodGet, false},
{"Commit path", "/commit", "DockerCommit", true, http.MethodGet, true},
{"Commit path disabled", "/commit", "DockerCommit", false, http.MethodGet, false},
{"Configs path", "/configs", "DockerConfigs", true, http.MethodGet, true},
{"Configs path disabled", "/configs", "DockerConfigs", false, http.MethodGet, false},
{"Distribution path", "/distribution", "DockerDistribution", true, http.MethodGet, true},
{"Distribution path disabled", "/distribution", "DockerDistribution", false, http.MethodGet, false},
{"Events path", "/events", "DockerEvents", true, http.MethodGet, true},
{"Events path disabled", "/events", "DockerEvents", false, http.MethodGet, false},
{"Exec path", "/exec", "DockerExec", true, http.MethodGet, true},
{"Exec path disabled", "/exec", "DockerExec", false, http.MethodGet, false},
{"Grpc path", "/grpc", "DockerGrpc", true, http.MethodGet, true},
{"Grpc path disabled", "/grpc", "DockerGrpc", false, http.MethodGet, false},
{"Images path", "/images", "DockerImages", true, http.MethodGet, true},
{"Images path disabled", "/images", "DockerImages", false, http.MethodGet, false},
{"Info path", "/info", "DockerInfo", true, http.MethodGet, true},
{"Info path disabled", "/info", "DockerInfo", false, http.MethodGet, false},
{"Networks path", "/networks", "DockerNetworks", true, http.MethodGet, true},
{"Networks path disabled", "/networks", "DockerNetworks", false, http.MethodGet, false},
{"Nodes path", "/nodes", "DockerNodes", true, http.MethodGet, true},
{"Nodes path disabled", "/nodes", "DockerNodes", false, http.MethodGet, false},
{"Plugins path", "/plugins", "DockerPlugins", true, http.MethodGet, true},
{"Plugins path disabled", "/plugins", "DockerPlugins", false, http.MethodGet, false},
{"Secrets path", "/secrets", "DockerSecrets", true, http.MethodGet, true},
{"Secrets path disabled", "/secrets", "DockerSecrets", false, http.MethodGet, false},
{"Services path", "/services", "DockerServices", true, http.MethodGet, true},
{"Services path disabled", "/services", "DockerServices", false, http.MethodGet, false},
{"Session path", "/session", "DockerSession", true, http.MethodGet, true},
{"Session path disabled", "/session", "DockerSession", false, http.MethodGet, false},
{"Swarm path", "/swarm", "DockerSwarm", true, http.MethodGet, true},
{"Swarm path disabled", "/swarm", "DockerSwarm", false, http.MethodGet, false},
{"System path", "/system", "DockerSystem", true, http.MethodGet, true},
{"System path disabled", "/system", "DockerSystem", false, http.MethodGet, false},
{"Tasks path", "/tasks", "DockerTasks", true, http.MethodGet, true},
{"Tasks path disabled", "/tasks", "DockerTasks", false, http.MethodGet, false},
{"Volumes path", "/volumes", "DockerVolumes", true, http.MethodGet, true},
{"Volumes path disabled", "/volumes", "DockerVolumes", false, http.MethodGet, false},
// Test versioned paths
{"Versioned auth", "/v1.41/auth", "DockerAuth", true, http.MethodGet, true},
{"Versioned auth disabled", "/v1.41/auth", "DockerAuth", false, http.MethodGet, false},
}
defer func() {
// Restore original env values
env.Load()
}()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset all Docker* env vars to false for this test
env.Load()
// Enable POST for all these tests
env.DockerPost = true
// Set the specific env var for this test
switch tt.envVarName {
case "DockerContainers":
env.DockerContainers = tt.envVarValue
case "DockerRestarts":
env.DockerRestarts = tt.envVarValue
case "DockerStart":
env.DockerStart = tt.envVarValue
case "DockerStop":
env.DockerStop = tt.envVarValue
case "DockerAuth":
env.DockerAuth = tt.envVarValue
case "DockerBuild":
env.DockerBuild = tt.envVarValue
case "DockerCommit":
env.DockerCommit = tt.envVarValue
case "DockerConfigs":
env.DockerConfigs = tt.envVarValue
case "DockerDistribution":
env.DockerDistribution = tt.envVarValue
case "DockerEvents":
env.DockerEvents = tt.envVarValue
case "DockerExec":
env.DockerExec = tt.envVarValue
case "DockerGrpc":
env.DockerGrpc = tt.envVarValue
case "DockerImages":
env.DockerImages = tt.envVarValue
case "DockerInfo":
env.DockerInfo = tt.envVarValue
case "DockerNetworks":
env.DockerNetworks = tt.envVarValue
case "DockerNodes":
env.DockerNodes = tt.envVarValue
case "DockerPlugins":
env.DockerPlugins = tt.envVarValue
case "DockerSecrets":
env.DockerSecrets = tt.envVarValue
case "DockerServices":
env.DockerServices = tt.envVarValue
case "DockerSession":
env.DockerSession = tt.envVarValue
case "DockerSwarm":
env.DockerSwarm = tt.envVarValue
case "DockerSystem":
env.DockerSystem = tt.envVarValue
case "DockerTasks":
env.DockerTasks = tt.envVarValue
case "DockerVolumes":
env.DockerVolumes = tt.envVarValue
default:
t.Fatalf("Unknown env var: %s", tt.envVarName)
}
// Create test handler
dockerHandler := NewDockerHandler()
// Test server to capture the response
recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest(tt.method, tt.path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Process the request
dockerHandler.ServeHTTP(recorder, req)
// Check if the status indicates if the path is allowed or not
isAllowed := recorder.Code != http.StatusForbidden
if isAllowed != tt.wantAllowed {
t.Errorf("Path %s with env %s=%v: got allowed=%v, want allowed=%v (status=%d)",
tt.path, tt.envVarName, tt.envVarValue, isAllowed, tt.wantAllowed, recorder.Code)
}
})
}
}
// TestNewDockerHandlerWithMockDocker mocks the Docker API to test the actual HTTP handler behavior
// This is a more comprehensive test that verifies the full request/response chain
func TestNewDockerHandlerWithMockDocker(t *testing.T) {
// Set up environment
env.DockerContainers = true
env.DockerPost = true
// Create the handler
handler := NewDockerHandler()
// Test a valid request
req, _ := http.NewRequest(http.MethodGet, "/containers", nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Errorf("Expected status OK for /containers, got %d", recorder.Code)
}
// Test a disallowed path
env.DockerContainers = false
handler = NewDockerHandler() // recreate with new env
req, _ = http.NewRequest(http.MethodGet, "/containers", nil)
recorder = httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status Forbidden for /containers when disabled, got %d", recorder.Code)
}
}

View File

@@ -1,38 +0,0 @@
package handler
import (
"net/http"
"net/url"
"github.com/docker/docker/client"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
)
func serviceUnavailable(w http.ResponseWriter, r *http.Request) {
http.Error(w, "docker socket is not available", http.StatusServiceUnavailable)
}
func mockDockerSocketHandler() http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("mock docker response"))
})
}
func DockerSocketHandler() http.HandlerFunc {
dockerClient, err := docker.NewClient(common.DockerHostFromEnv)
if err != nil {
logging.Warn().Err(err).Msg("failed to connect to docker client")
return serviceUnavailable
}
rp := reverseproxy.NewReverseProxy("docker", types.NewURL(&url.URL{
Scheme: "http",
Host: client.DummyHost,
}), dockerClient.HTTPClient().Transport)
return rp.ServeHTTP
}

View File

@@ -2,201 +2,35 @@ package handler
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/pkg"
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
type ServeMux struct{ *http.ServeMux }
func (mux ServeMux) HandleMethods(methods, endpoint string, handler http.HandlerFunc) {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+agent.APIEndpointBase+endpoint, handler)
}
func (mux ServeMux) HandleEndpoint(method, endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(method+" "+agent.APIEndpointBase+endpoint, handler)
}
func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
}
type NopWriteCloser struct {
io.Writer
}
func (NopWriteCloser) Close() error {
return nil
}
func NewAgentHandler() http.Handler {
mux := ServeMux{http.NewServeMux()}
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
mux.HandleMethods("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
mux.HandleMethods("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
mux.HandleEndpoint("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.AgentName)
})
mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth)
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc())
mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler())
return mux
}
func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "Endpoint not allowed", http.StatusForbidden)
}
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
func NewDockerHandler() http.Handler {
r := mux.NewRouter()
var socketHandler http.HandlerFunc
if common.IsTest {
socketHandler = mockDockerSocketHandler()
} else {
socketHandler = DockerSocketHandler()
}
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
allowedPaths := []string{}
deniedPaths := []string{}
if env.DockerContainers {
allowedPaths = append(allowedPaths, "/containers")
if !env.DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
deniedPaths = append(deniedPaths, containerPath+"/restart")
deniedPaths = append(deniedPaths, containerPath+"/kill")
}
if !env.DockerStart {
deniedPaths = append(deniedPaths, containerPath+"/start")
}
if !env.DockerStop && env.DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
}
}
if env.DockerAuth {
allowedPaths = append(allowedPaths, "/auth")
}
if env.DockerBuild {
allowedPaths = append(allowedPaths, "/build")
}
if env.DockerCommit {
allowedPaths = append(allowedPaths, "/commit")
}
if env.DockerConfigs {
allowedPaths = append(allowedPaths, "/configs")
}
if env.DockerDistribution {
allowedPaths = append(allowedPaths, "/distribution")
}
if env.DockerEvents {
allowedPaths = append(allowedPaths, "/events")
}
if env.DockerExec {
allowedPaths = append(allowedPaths, "/exec")
}
if env.DockerGrpc {
allowedPaths = append(allowedPaths, "/grpc")
}
if env.DockerImages {
allowedPaths = append(allowedPaths, "/images")
}
if env.DockerInfo {
allowedPaths = append(allowedPaths, "/info")
}
if env.DockerNetworks {
allowedPaths = append(allowedPaths, "/networks")
}
if env.DockerNodes {
allowedPaths = append(allowedPaths, "/nodes")
}
if env.DockerPing {
allowedPaths = append(allowedPaths, "/_ping")
}
if env.DockerPlugins {
allowedPaths = append(allowedPaths, "/plugins")
}
if env.DockerSecrets {
allowedPaths = append(allowedPaths, "/secrets")
}
if env.DockerServices {
allowedPaths = append(allowedPaths, "/services")
}
if env.DockerSession {
allowedPaths = append(allowedPaths, "/session")
}
if env.DockerSwarm {
allowedPaths = append(allowedPaths, "/swarm")
}
if env.DockerSystem {
allowedPaths = append(allowedPaths, "/system")
}
if env.DockerTasks {
allowedPaths = append(allowedPaths, "/tasks")
}
if env.DockerVersion {
allowedPaths = append(allowedPaths, "/version")
}
if env.DockerVolumes {
allowedPaths = append(allowedPaths, "/volumes")
}
// Helper to determine if a path should be treated as a prefix
isPrefixPath := func(path string) bool {
return strings.Count(path, "/") == 1
}
// 1. Register Denied Paths (specific)
for _, path := range deniedPaths {
// Handle with version prefix
r.HandleFunc(apiVersionPrefix+path, endpointNotAllowed)
// Handle without version prefix
r.HandleFunc(path, endpointNotAllowed)
}
// 2. Register Allowed Paths
for _, p := range allowedPaths {
fullPathWithVersion := apiVersionPrefix + p
if isPrefixPath(p) {
r.PathPrefix(fullPathWithVersion).Handler(socketHandler)
r.PathPrefix(p).Handler(socketHandler)
} else {
r.HandleFunc(fullPathWithVersion, socketHandler)
r.HandleFunc(p, socketHandler)
}
}
// 3. Add fallback for any other routes
r.PathPrefix("/").HandlerFunc(endpointNotAllowed)
// HTTP method filtering
if !env.DockerPost {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodPost, http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}

View File

@@ -3,18 +3,26 @@ package handler
import (
"crypto/tls"
"net/http"
"net/url"
"net/http/httputil"
"strconv"
"time"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/agentproxy"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
)
func NewTransport() *http.Transport {
return &http.Transport{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
}
func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
host := r.Header.Get(agentproxy.HeaderXProxyHost)
isHTTPS, _ := strconv.ParseBool(r.Header.Get(agentproxy.HeaderXProxyHTTPS))
@@ -34,11 +42,9 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
scheme = "https"
}
var transport *http.Transport
transport := NewTransport()
if skipTLSVerify {
transport = gphttp.NewTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true})
} else {
transport = gphttp.NewTransport()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if responseHeaderTimeout > 0 {
@@ -49,14 +55,13 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Host = ""
r.URL.Path = r.URL.Path[agent.HTTPProxyURLPrefixLen:] // strip the {API_BASE}/proxy/http prefix
r.RequestURI = r.URL.String()
r.URL.Host = host
r.URL.Scheme = scheme
logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String())
rp := reverseproxy.NewReverseProxy("agent", types.NewURL(&url.URL{
Scheme: scheme,
Host: host,
}), transport)
rp := &httputil.ReverseProxy{
Director: func(r *http.Request) {
r.URL.Scheme = scheme
r.URL.Host = host
},
Transport: transport,
}
rp.ServeHTTP(w, r)
}