mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-26 11:01:07 +01:00
feat(socket-proxy): implement Docker socket proxy and related configurations
- Updated Dockerfile and Makefile for socket-proxy build. - Modified go.mod to include necessary dependencies. - Updated CI workflows for socket-proxy integration. - Better module isolation - Code refactor
This commit is contained in:
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,10 +16,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/agent/pkg/certs"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
)
|
||||
|
||||
@@ -80,7 +79,7 @@ func (cfg *AgentConfig) Parse(addr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) error {
|
||||
func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) error {
|
||||
clientCert, err := tls.X509KeyPair(crt, key)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -90,7 +89,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
|
||||
caCertPool := x509.NewCertPool()
|
||||
ok := caCertPool.AppendCertsFromPEM(ca)
|
||||
if !ok {
|
||||
return gperr.New("invalid ca certificate")
|
||||
return errors.New("invalid ca certificate")
|
||||
}
|
||||
|
||||
cfg.tlsConfig = &tls.Config{
|
||||
@@ -102,7 +101,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
|
||||
// create transport and http client
|
||||
cfg.httpClient = cfg.NewHTTPClient()
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent.Context(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// get agent name
|
||||
@@ -131,23 +130,23 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error {
|
||||
func (cfg *AgentConfig) Start(ctx context.Context) error {
|
||||
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
|
||||
if !ok {
|
||||
return gperr.New("invalid agent host").Subject(cfg.Addr)
|
||||
return fmt.Errorf("invalid agent host: %s", cfg.Addr)
|
||||
}
|
||||
|
||||
certData, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return gperr.Wrap(err, "failed to read agent certs")
|
||||
return fmt.Errorf("failed to read agent certs: %w", err)
|
||||
}
|
||||
|
||||
ca, crt, key, err := certs.ExtractCert(certData)
|
||||
if err != nil {
|
||||
return gperr.Wrap(err, "failed to extract agent certs")
|
||||
return fmt.Errorf("failed to extract agent certs: %w", err)
|
||||
}
|
||||
|
||||
return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key))
|
||||
return cfg.StartWithCerts(ctx, ca, crt, key)
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) NewHTTPClient() *http.Client {
|
||||
@@ -171,8 +170,10 @@ func (cfg *AgentConfig) Transport() *http.Transport {
|
||||
}
|
||||
}
|
||||
|
||||
var dialer = &net.Dialer{Timeout: 5 * time.Second}
|
||||
|
||||
func (cfg *AgentConfig) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
return gphttp.DefaultDialer.DialContext(ctx, "tcp", cfg.Addr)
|
||||
return dialer.DialContext(ctx, "tcp", cfg.Addr)
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) Name() string {
|
||||
|
||||
@@ -8,59 +8,59 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAgent(t *testing.T) {
|
||||
ca, srv, client, err := NewAgent()
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, ca != nil)
|
||||
ExpectTrue(t, srv != nil)
|
||||
ExpectTrue(t, client != nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ca)
|
||||
require.NotNil(t, srv)
|
||||
require.NotNil(t, client)
|
||||
}
|
||||
|
||||
func TestPEMPair(t *testing.T) {
|
||||
ca, srv, client, err := NewAgent()
|
||||
ExpectNoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, p := range []*PEMPair{ca, srv, client} {
|
||||
t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) {
|
||||
var pp PEMPair
|
||||
err := pp.Load(p.String())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, p.Cert, pp.Cert)
|
||||
ExpectEqual(t, p.Key, pp.Key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, p.Cert, pp.Cert)
|
||||
require.Equal(t, p.Key, pp.Key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPEMPairToTLSCert(t *testing.T) {
|
||||
ca, srv, client, err := NewAgent()
|
||||
ExpectNoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, p := range []*PEMPair{ca, srv, client} {
|
||||
t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) {
|
||||
cert, err := p.ToTLSCert()
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, cert != nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerClient(t *testing.T) {
|
||||
ca, srv, client, err := NewAgent()
|
||||
ExpectNoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
srvTLS, err := srv.ToTLSCert()
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, srvTLS != nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, srvTLS)
|
||||
|
||||
clientTLS, err := client.ToTLSCert()
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, clientTLS != nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, clientTLS)
|
||||
|
||||
caPool := x509.NewCertPool()
|
||||
ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert))
|
||||
require.True(t, caPool.AppendCertsFromPEM(ca.Cert))
|
||||
|
||||
srvTLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*srvTLS},
|
||||
@@ -86,6 +86,6 @@ func TestServerClient(t *testing.T) {
|
||||
}
|
||||
|
||||
resp, err := httpClient.Get(server.URL)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, resp.StatusCode, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
60
agent/pkg/env/env.go
vendored
60
agent/pkg/env/env.go
vendored
@@ -20,35 +20,6 @@ var (
|
||||
AgentSkipClientCertCheck bool
|
||||
AgentCACert string
|
||||
AgentSSLCert string
|
||||
|
||||
DockerSocketAddr string
|
||||
DockerPost bool
|
||||
DockerRestarts bool
|
||||
DockerStart bool
|
||||
DockerStop bool
|
||||
DockerAuth bool
|
||||
DockerBuild bool
|
||||
DockerCommit bool
|
||||
DockerConfigs bool
|
||||
DockerContainers bool
|
||||
DockerDistribution bool
|
||||
DockerEvents bool
|
||||
DockerExec bool
|
||||
DockerGrpc bool
|
||||
DockerImages bool
|
||||
DockerInfo bool
|
||||
DockerNetworks bool
|
||||
DockerNodes bool
|
||||
DockerPing bool
|
||||
DockerPlugins bool
|
||||
DockerSecrets bool
|
||||
DockerServices bool
|
||||
DockerSession bool
|
||||
DockerSwarm bool
|
||||
DockerSystem bool
|
||||
DockerTasks bool
|
||||
DockerVersion bool
|
||||
DockerVolumes bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -62,35 +33,4 @@ func Load() {
|
||||
|
||||
AgentCACert = common.GetEnvString("AGENT_CA_CERT", "")
|
||||
AgentSSLCert = common.GetEnvString("AGENT_SSL_CERT", "")
|
||||
|
||||
// docker socket proxy
|
||||
DockerSocketAddr = common.GetEnvString("DOCKER_SOCKET_ADDR", "127.0.0.1:2375")
|
||||
|
||||
DockerPost = common.GetEnvBool("POST", false)
|
||||
DockerRestarts = common.GetEnvBool("ALLOW_RESTARTS", false)
|
||||
DockerStart = common.GetEnvBool("ALLOW_START", false)
|
||||
DockerStop = common.GetEnvBool("ALLOW_STOP", false)
|
||||
DockerAuth = common.GetEnvBool("AUTH", false)
|
||||
DockerBuild = common.GetEnvBool("BUILD", false)
|
||||
DockerCommit = common.GetEnvBool("COMMIT", false)
|
||||
DockerConfigs = common.GetEnvBool("CONFIGS", false)
|
||||
DockerContainers = common.GetEnvBool("CONTAINERS", false)
|
||||
DockerDistribution = common.GetEnvBool("DISTRIBUTION", false)
|
||||
DockerEvents = common.GetEnvBool("EVENTS", true)
|
||||
DockerExec = common.GetEnvBool("EXEC", false)
|
||||
DockerGrpc = common.GetEnvBool("GRPC", false)
|
||||
DockerImages = common.GetEnvBool("IMAGES", false)
|
||||
DockerInfo = common.GetEnvBool("INFO", false)
|
||||
DockerNetworks = common.GetEnvBool("NETWORKS", false)
|
||||
DockerNodes = common.GetEnvBool("NODES", false)
|
||||
DockerPing = common.GetEnvBool("PING", true)
|
||||
DockerPlugins = common.GetEnvBool("PLUGINS", false)
|
||||
DockerSecrets = common.GetEnvBool("SECRETS", false)
|
||||
DockerServices = common.GetEnvBool("SERVICES", false)
|
||||
DockerSession = common.GetEnvBool("SESSION", false)
|
||||
DockerSwarm = common.GetEnvBool("SWARM", false)
|
||||
DockerSystem = common.GetEnvBool("SYSTEM", false)
|
||||
DockerTasks = common.GetEnvBool("TASKS", false)
|
||||
DockerVersion = common.GetEnvBool("VERSION", true)
|
||||
DockerVolumes = common.GetEnvBool("VOLUMES", false)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
|
||||
)
|
||||
@@ -73,5 +73,7 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
gphttp.RespondJSON(w, r, result)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
@@ -1,414 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
)
|
||||
|
||||
func TestNewDockerHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
envSetup func()
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "GET _ping allowed by default",
|
||||
method: http.MethodGet,
|
||||
path: "/_ping",
|
||||
envSetup: func() {},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "GET version allowed by default",
|
||||
method: http.MethodGet,
|
||||
path: "/version",
|
||||
envSetup: func() {},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "GET containers allowed when enabled",
|
||||
method: http.MethodGet,
|
||||
path: "/containers",
|
||||
envSetup: func() {
|
||||
env.DockerContainers = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "GET containers not allowed when disabled",
|
||||
method: http.MethodGet,
|
||||
path: "/containers",
|
||||
envSetup: func() {
|
||||
env.DockerContainers = false
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "POST not allowed by default",
|
||||
method: http.MethodPost,
|
||||
path: "/_ping",
|
||||
envSetup: func() {
|
||||
env.DockerPost = false
|
||||
},
|
||||
wantStatusCode: http.StatusMethodNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "POST allowed when enabled",
|
||||
method: http.MethodPost,
|
||||
path: "/_ping",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerPing = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Container restart not allowed when disabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/restart",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerRestarts = false
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Container restart allowed when enabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/restart",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerRestarts = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Container start not allowed when disabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/start",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerStart = false
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Container start allowed when enabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/start",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerStart = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Container stop not allowed when disabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/stop",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerStop = false
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Container stop allowed when enabled",
|
||||
method: http.MethodPost,
|
||||
path: "/containers/test-container/stop",
|
||||
envSetup: func() {
|
||||
env.DockerPost = true
|
||||
env.DockerContainers = true
|
||||
env.DockerStop = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Versioned API paths work",
|
||||
method: http.MethodGet,
|
||||
path: "/v1.41/version",
|
||||
envSetup: func() {
|
||||
env.DockerVersion = true
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "PUT method not allowed",
|
||||
method: http.MethodPut,
|
||||
path: "/version",
|
||||
envSetup: func() {
|
||||
env.DockerVersion = true
|
||||
},
|
||||
wantStatusCode: http.StatusMethodNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "DELETE method not allowed",
|
||||
method: http.MethodDelete,
|
||||
path: "/version",
|
||||
envSetup: func() {
|
||||
env.DockerVersion = true
|
||||
},
|
||||
wantStatusCode: http.StatusMethodNotAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
// Save original env values to restore after tests
|
||||
originalContainers := env.DockerContainers
|
||||
originalRestarts := env.DockerRestarts
|
||||
originalStart := env.DockerStart
|
||||
originalStop := env.DockerStop
|
||||
originalPost := env.DockerPost
|
||||
originalPing := env.DockerPing
|
||||
originalVersion := env.DockerVersion
|
||||
|
||||
defer func() {
|
||||
// Restore original values
|
||||
env.DockerContainers = originalContainers
|
||||
env.DockerRestarts = originalRestarts
|
||||
env.DockerStart = originalStart
|
||||
env.DockerStop = originalStop
|
||||
env.DockerPost = originalPost
|
||||
env.DockerPing = originalPing
|
||||
env.DockerVersion = originalVersion
|
||||
}()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup environment for this test
|
||||
tt.envSetup()
|
||||
|
||||
// Create test handler that will record the response for verification
|
||||
dockerHandler := NewDockerHandler()
|
||||
|
||||
// Test server to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(tt.method, tt.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Process the request
|
||||
dockerHandler.ServeHTTP(recorder, req)
|
||||
|
||||
// Check response
|
||||
if recorder.Code != tt.wantStatusCode {
|
||||
t.Errorf("Expected status code %d, got %d",
|
||||
tt.wantStatusCode, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// This test focuses on checking that all the path prefix handling works correctly
|
||||
func TestNewDockerHandler_PathHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
envVarName string
|
||||
envVarValue bool
|
||||
method string
|
||||
wantAllowed bool
|
||||
}{
|
||||
{"Container path", "/containers/json", "DockerContainers", true, http.MethodGet, true},
|
||||
{"Container path disabled", "/containers/json", "DockerContainers", false, http.MethodGet, false},
|
||||
|
||||
{"Auth path", "/auth", "DockerAuth", true, http.MethodGet, true},
|
||||
{"Auth path disabled", "/auth", "DockerAuth", false, http.MethodGet, false},
|
||||
|
||||
{"Build path", "/build", "DockerBuild", true, http.MethodGet, true},
|
||||
{"Build path disabled", "/build", "DockerBuild", false, http.MethodGet, false},
|
||||
|
||||
{"Commit path", "/commit", "DockerCommit", true, http.MethodGet, true},
|
||||
{"Commit path disabled", "/commit", "DockerCommit", false, http.MethodGet, false},
|
||||
|
||||
{"Configs path", "/configs", "DockerConfigs", true, http.MethodGet, true},
|
||||
{"Configs path disabled", "/configs", "DockerConfigs", false, http.MethodGet, false},
|
||||
|
||||
{"Distribution path", "/distribution", "DockerDistribution", true, http.MethodGet, true},
|
||||
{"Distribution path disabled", "/distribution", "DockerDistribution", false, http.MethodGet, false},
|
||||
|
||||
{"Events path", "/events", "DockerEvents", true, http.MethodGet, true},
|
||||
{"Events path disabled", "/events", "DockerEvents", false, http.MethodGet, false},
|
||||
|
||||
{"Exec path", "/exec", "DockerExec", true, http.MethodGet, true},
|
||||
{"Exec path disabled", "/exec", "DockerExec", false, http.MethodGet, false},
|
||||
|
||||
{"Grpc path", "/grpc", "DockerGrpc", true, http.MethodGet, true},
|
||||
{"Grpc path disabled", "/grpc", "DockerGrpc", false, http.MethodGet, false},
|
||||
|
||||
{"Images path", "/images", "DockerImages", true, http.MethodGet, true},
|
||||
{"Images path disabled", "/images", "DockerImages", false, http.MethodGet, false},
|
||||
|
||||
{"Info path", "/info", "DockerInfo", true, http.MethodGet, true},
|
||||
{"Info path disabled", "/info", "DockerInfo", false, http.MethodGet, false},
|
||||
|
||||
{"Networks path", "/networks", "DockerNetworks", true, http.MethodGet, true},
|
||||
{"Networks path disabled", "/networks", "DockerNetworks", false, http.MethodGet, false},
|
||||
|
||||
{"Nodes path", "/nodes", "DockerNodes", true, http.MethodGet, true},
|
||||
{"Nodes path disabled", "/nodes", "DockerNodes", false, http.MethodGet, false},
|
||||
|
||||
{"Plugins path", "/plugins", "DockerPlugins", true, http.MethodGet, true},
|
||||
{"Plugins path disabled", "/plugins", "DockerPlugins", false, http.MethodGet, false},
|
||||
|
||||
{"Secrets path", "/secrets", "DockerSecrets", true, http.MethodGet, true},
|
||||
{"Secrets path disabled", "/secrets", "DockerSecrets", false, http.MethodGet, false},
|
||||
|
||||
{"Services path", "/services", "DockerServices", true, http.MethodGet, true},
|
||||
{"Services path disabled", "/services", "DockerServices", false, http.MethodGet, false},
|
||||
|
||||
{"Session path", "/session", "DockerSession", true, http.MethodGet, true},
|
||||
{"Session path disabled", "/session", "DockerSession", false, http.MethodGet, false},
|
||||
|
||||
{"Swarm path", "/swarm", "DockerSwarm", true, http.MethodGet, true},
|
||||
{"Swarm path disabled", "/swarm", "DockerSwarm", false, http.MethodGet, false},
|
||||
|
||||
{"System path", "/system", "DockerSystem", true, http.MethodGet, true},
|
||||
{"System path disabled", "/system", "DockerSystem", false, http.MethodGet, false},
|
||||
|
||||
{"Tasks path", "/tasks", "DockerTasks", true, http.MethodGet, true},
|
||||
{"Tasks path disabled", "/tasks", "DockerTasks", false, http.MethodGet, false},
|
||||
|
||||
{"Volumes path", "/volumes", "DockerVolumes", true, http.MethodGet, true},
|
||||
{"Volumes path disabled", "/volumes", "DockerVolumes", false, http.MethodGet, false},
|
||||
|
||||
// Test versioned paths
|
||||
{"Versioned auth", "/v1.41/auth", "DockerAuth", true, http.MethodGet, true},
|
||||
{"Versioned auth disabled", "/v1.41/auth", "DockerAuth", false, http.MethodGet, false},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Restore original env values
|
||||
env.Load()
|
||||
}()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset all Docker* env vars to false for this test
|
||||
env.Load()
|
||||
|
||||
// Enable POST for all these tests
|
||||
env.DockerPost = true
|
||||
|
||||
// Set the specific env var for this test
|
||||
switch tt.envVarName {
|
||||
case "DockerContainers":
|
||||
env.DockerContainers = tt.envVarValue
|
||||
case "DockerRestarts":
|
||||
env.DockerRestarts = tt.envVarValue
|
||||
case "DockerStart":
|
||||
env.DockerStart = tt.envVarValue
|
||||
case "DockerStop":
|
||||
env.DockerStop = tt.envVarValue
|
||||
case "DockerAuth":
|
||||
env.DockerAuth = tt.envVarValue
|
||||
case "DockerBuild":
|
||||
env.DockerBuild = tt.envVarValue
|
||||
case "DockerCommit":
|
||||
env.DockerCommit = tt.envVarValue
|
||||
case "DockerConfigs":
|
||||
env.DockerConfigs = tt.envVarValue
|
||||
case "DockerDistribution":
|
||||
env.DockerDistribution = tt.envVarValue
|
||||
case "DockerEvents":
|
||||
env.DockerEvents = tt.envVarValue
|
||||
case "DockerExec":
|
||||
env.DockerExec = tt.envVarValue
|
||||
case "DockerGrpc":
|
||||
env.DockerGrpc = tt.envVarValue
|
||||
case "DockerImages":
|
||||
env.DockerImages = tt.envVarValue
|
||||
case "DockerInfo":
|
||||
env.DockerInfo = tt.envVarValue
|
||||
case "DockerNetworks":
|
||||
env.DockerNetworks = tt.envVarValue
|
||||
case "DockerNodes":
|
||||
env.DockerNodes = tt.envVarValue
|
||||
case "DockerPlugins":
|
||||
env.DockerPlugins = tt.envVarValue
|
||||
case "DockerSecrets":
|
||||
env.DockerSecrets = tt.envVarValue
|
||||
case "DockerServices":
|
||||
env.DockerServices = tt.envVarValue
|
||||
case "DockerSession":
|
||||
env.DockerSession = tt.envVarValue
|
||||
case "DockerSwarm":
|
||||
env.DockerSwarm = tt.envVarValue
|
||||
case "DockerSystem":
|
||||
env.DockerSystem = tt.envVarValue
|
||||
case "DockerTasks":
|
||||
env.DockerTasks = tt.envVarValue
|
||||
case "DockerVolumes":
|
||||
env.DockerVolumes = tt.envVarValue
|
||||
default:
|
||||
t.Fatalf("Unknown env var: %s", tt.envVarName)
|
||||
}
|
||||
|
||||
// Create test handler
|
||||
dockerHandler := NewDockerHandler()
|
||||
|
||||
// Test server to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(tt.method, tt.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Process the request
|
||||
dockerHandler.ServeHTTP(recorder, req)
|
||||
|
||||
// Check if the status indicates if the path is allowed or not
|
||||
isAllowed := recorder.Code != http.StatusForbidden
|
||||
if isAllowed != tt.wantAllowed {
|
||||
t.Errorf("Path %s with env %s=%v: got allowed=%v, want allowed=%v (status=%d)",
|
||||
tt.path, tt.envVarName, tt.envVarValue, isAllowed, tt.wantAllowed, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewDockerHandlerWithMockDocker mocks the Docker API to test the actual HTTP handler behavior
|
||||
// This is a more comprehensive test that verifies the full request/response chain
|
||||
func TestNewDockerHandlerWithMockDocker(t *testing.T) {
|
||||
// Set up environment
|
||||
env.DockerContainers = true
|
||||
env.DockerPost = true
|
||||
|
||||
// Create the handler
|
||||
handler := NewDockerHandler()
|
||||
|
||||
// Test a valid request
|
||||
req, _ := http.NewRequest(http.MethodGet, "/containers", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK for /containers, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
// Test a disallowed path
|
||||
env.DockerContainers = false
|
||||
handler = NewDockerHandler() // recreate with new env
|
||||
|
||||
req, _ = http.NewRequest(http.MethodGet, "/containers", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status Forbidden for /containers when disabled, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
func serviceUnavailable(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "docker socket is not available", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
func mockDockerSocketHandler() http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("mock docker response"))
|
||||
})
|
||||
}
|
||||
|
||||
func DockerSocketHandler() http.HandlerFunc {
|
||||
dockerClient, err := docker.NewClient(common.DockerHostFromEnv)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msg("failed to connect to docker client")
|
||||
return serviceUnavailable
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy("docker", types.NewURL(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: client.DummyHost,
|
||||
}), dockerClient.HTTPClient().Transport)
|
||||
|
||||
return rp.ServeHTTP
|
||||
}
|
||||
@@ -2,201 +2,35 @@ package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
|
||||
func (mux ServeMux) HandleMethods(methods, endpoint string, handler http.HandlerFunc) {
|
||||
for _, m := range strutils.CommaSeperatedList(methods) {
|
||||
mux.ServeMux.HandleFunc(m+" "+agent.APIEndpointBase+endpoint, handler)
|
||||
}
|
||||
func (mux ServeMux) HandleEndpoint(method, endpoint string, handler http.HandlerFunc) {
|
||||
mux.ServeMux.HandleFunc(method+" "+agent.APIEndpointBase+endpoint, handler)
|
||||
}
|
||||
|
||||
func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
|
||||
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
|
||||
}
|
||||
|
||||
type NopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (NopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAgentHandler() http.Handler {
|
||||
mux := ServeMux{http.NewServeMux()}
|
||||
|
||||
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
||||
mux.HandleMethods("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
|
||||
mux.HandleMethods("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleEndpoint("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
|
||||
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, env.AgentName)
|
||||
})
|
||||
mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth)
|
||||
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc())
|
||||
mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
|
||||
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
|
||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
|
||||
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler())
|
||||
return mux
|
||||
}
|
||||
|
||||
func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "Endpoint not allowed", http.StatusForbidden)
|
||||
}
|
||||
|
||||
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
|
||||
func NewDockerHandler() http.Handler {
|
||||
r := mux.NewRouter()
|
||||
var socketHandler http.HandlerFunc
|
||||
if common.IsTest {
|
||||
socketHandler = mockDockerSocketHandler()
|
||||
} else {
|
||||
socketHandler = DockerSocketHandler()
|
||||
}
|
||||
|
||||
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
|
||||
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
|
||||
|
||||
allowedPaths := []string{}
|
||||
deniedPaths := []string{}
|
||||
|
||||
if env.DockerContainers {
|
||||
allowedPaths = append(allowedPaths, "/containers")
|
||||
if !env.DockerRestarts {
|
||||
deniedPaths = append(deniedPaths, containerPath+"/stop")
|
||||
deniedPaths = append(deniedPaths, containerPath+"/restart")
|
||||
deniedPaths = append(deniedPaths, containerPath+"/kill")
|
||||
}
|
||||
if !env.DockerStart {
|
||||
deniedPaths = append(deniedPaths, containerPath+"/start")
|
||||
}
|
||||
if !env.DockerStop && env.DockerRestarts {
|
||||
deniedPaths = append(deniedPaths, containerPath+"/stop")
|
||||
}
|
||||
}
|
||||
if env.DockerAuth {
|
||||
allowedPaths = append(allowedPaths, "/auth")
|
||||
}
|
||||
if env.DockerBuild {
|
||||
allowedPaths = append(allowedPaths, "/build")
|
||||
}
|
||||
if env.DockerCommit {
|
||||
allowedPaths = append(allowedPaths, "/commit")
|
||||
}
|
||||
if env.DockerConfigs {
|
||||
allowedPaths = append(allowedPaths, "/configs")
|
||||
}
|
||||
if env.DockerDistribution {
|
||||
allowedPaths = append(allowedPaths, "/distribution")
|
||||
}
|
||||
if env.DockerEvents {
|
||||
allowedPaths = append(allowedPaths, "/events")
|
||||
}
|
||||
if env.DockerExec {
|
||||
allowedPaths = append(allowedPaths, "/exec")
|
||||
}
|
||||
if env.DockerGrpc {
|
||||
allowedPaths = append(allowedPaths, "/grpc")
|
||||
}
|
||||
if env.DockerImages {
|
||||
allowedPaths = append(allowedPaths, "/images")
|
||||
}
|
||||
if env.DockerInfo {
|
||||
allowedPaths = append(allowedPaths, "/info")
|
||||
}
|
||||
if env.DockerNetworks {
|
||||
allowedPaths = append(allowedPaths, "/networks")
|
||||
}
|
||||
if env.DockerNodes {
|
||||
allowedPaths = append(allowedPaths, "/nodes")
|
||||
}
|
||||
if env.DockerPing {
|
||||
allowedPaths = append(allowedPaths, "/_ping")
|
||||
}
|
||||
if env.DockerPlugins {
|
||||
allowedPaths = append(allowedPaths, "/plugins")
|
||||
}
|
||||
if env.DockerSecrets {
|
||||
allowedPaths = append(allowedPaths, "/secrets")
|
||||
}
|
||||
if env.DockerServices {
|
||||
allowedPaths = append(allowedPaths, "/services")
|
||||
}
|
||||
if env.DockerSession {
|
||||
allowedPaths = append(allowedPaths, "/session")
|
||||
}
|
||||
if env.DockerSwarm {
|
||||
allowedPaths = append(allowedPaths, "/swarm")
|
||||
}
|
||||
if env.DockerSystem {
|
||||
allowedPaths = append(allowedPaths, "/system")
|
||||
}
|
||||
if env.DockerTasks {
|
||||
allowedPaths = append(allowedPaths, "/tasks")
|
||||
}
|
||||
if env.DockerVersion {
|
||||
allowedPaths = append(allowedPaths, "/version")
|
||||
}
|
||||
if env.DockerVolumes {
|
||||
allowedPaths = append(allowedPaths, "/volumes")
|
||||
}
|
||||
|
||||
// Helper to determine if a path should be treated as a prefix
|
||||
isPrefixPath := func(path string) bool {
|
||||
return strings.Count(path, "/") == 1
|
||||
}
|
||||
|
||||
// 1. Register Denied Paths (specific)
|
||||
for _, path := range deniedPaths {
|
||||
// Handle with version prefix
|
||||
r.HandleFunc(apiVersionPrefix+path, endpointNotAllowed)
|
||||
// Handle without version prefix
|
||||
r.HandleFunc(path, endpointNotAllowed)
|
||||
}
|
||||
|
||||
// 2. Register Allowed Paths
|
||||
for _, p := range allowedPaths {
|
||||
fullPathWithVersion := apiVersionPrefix + p
|
||||
if isPrefixPath(p) {
|
||||
r.PathPrefix(fullPathWithVersion).Handler(socketHandler)
|
||||
r.PathPrefix(p).Handler(socketHandler)
|
||||
} else {
|
||||
r.HandleFunc(fullPathWithVersion, socketHandler)
|
||||
r.HandleFunc(p, socketHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Add fallback for any other routes
|
||||
r.PathPrefix("/").HandlerFunc(endpointNotAllowed)
|
||||
|
||||
// HTTP method filtering
|
||||
if !env.DockerPost {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
r.ServeHTTP(w, req)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
switch req.Method {
|
||||
case http.MethodPost, http.MethodGet:
|
||||
r.ServeHTTP(w, req)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,18 +3,26 @@ package handler
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agentproxy"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
func NewTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
WriteBufferSize: 16 * 1024, // 16KB
|
||||
ReadBufferSize: 16 * 1024, // 16KB
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Header.Get(agentproxy.HeaderXProxyHost)
|
||||
isHTTPS, _ := strconv.ParseBool(r.Header.Get(agentproxy.HeaderXProxyHTTPS))
|
||||
@@ -34,11 +42,9 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var transport *http.Transport
|
||||
transport := NewTransport()
|
||||
if skipTLSVerify {
|
||||
transport = gphttp.NewTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true})
|
||||
} else {
|
||||
transport = gphttp.NewTransport()
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
if responseHeaderTimeout > 0 {
|
||||
@@ -49,14 +55,13 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.URL.Host = ""
|
||||
r.URL.Path = r.URL.Path[agent.HTTPProxyURLPrefixLen:] // strip the {API_BASE}/proxy/http prefix
|
||||
r.RequestURI = r.URL.String()
|
||||
r.URL.Host = host
|
||||
r.URL.Scheme = scheme
|
||||
|
||||
logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String())
|
||||
|
||||
rp := reverseproxy.NewReverseProxy("agent", types.NewURL(&url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
}), transport)
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(r *http.Request) {
|
||||
r.URL.Scheme = scheme
|
||||
r.URL.Host = host
|
||||
},
|
||||
Transport: transport,
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user