diff --git a/internal/docker/client.go b/internal/docker/client.go index fa898845..20fcd0e2 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -1,7 +1,10 @@ package docker import ( + "context" "errors" + "fmt" + "net" "net/http" "sync" "sync/atomic" @@ -9,7 +12,9 @@ import ( "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/common" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) @@ -21,11 +26,14 @@ type ( key string refCount uint32 closedOn int64 + + addr string + dial func(ctx context.Context) (net.Conn, error) } ) var ( - clientMap = make(map[string]*SharedClient, 5) + clientMap = make(map[string]*SharedClient, 10) clientMapMu sync.RWMutex clientOptEnvHost = []client.Opt{ @@ -74,10 +82,7 @@ func closeTimedOutClients() { now := time.Now().Unix() for _, c := range clientMap { - if c.closedOn == 0 { - continue - } - if c.refCount == 0 && now-c.closedOn > clientTTLSecs { + if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs { delete(clientMap, c.key) c.Client.Close() logging.Debug().Str("host", c.key).Msg("docker client closed") @@ -85,13 +90,26 @@ func closeTimedOutClients() { } } +func (c *SharedClient) Address() string { + return c.addr +} + +func (c *SharedClient) CheckConnection(ctx context.Context) error { + conn, err := c.dial(ctx) + if err != nil { + return err + } + conn.Close() + return nil +} + // if the client is still referenced, this is no-op. func (c *SharedClient) Close() { atomic.StoreInt64(&c.closedOn, time.Now().Unix()) atomic.AddUint32(&c.refCount, ^uint32(0)) } -// ConnectClient creates a new Docker client connection to the specified host. +// NewClient creates a new Docker client connection to the specified host. // // Returns existing client if available. // @@ -101,7 +119,7 @@ func (c *SharedClient) Close() { // Returns: // - Client: the Docker client connection. // - error: an error if the connection failed. -func ConnectClient(host string) (*SharedClient, error) { +func NewClient(host string) (*SharedClient, error) { clientMapMu.Lock() defer clientMapMu.Unlock() @@ -113,33 +131,49 @@ func ConnectClient(host string) (*SharedClient, error) { // create client var opt []client.Opt + var addr string + var dial func(ctx context.Context) (net.Conn, error) - switch host { - case "": - return nil, errors.New("empty docker host") - case common.DockerHostFromEnv: - opt = clientOptEnvHost - default: - helper, err := connhelper.GetConnectionHelper(host) - if err != nil { - logging.Panic().Err(err).Msg("failed to get connection helper") + if agent.IsDockerHostAgent(host) { + cfg, ok := config.GetInstance().GetAgent(host) + if !ok { + panic(fmt.Errorf("agent %q not found", host)) } - if helper != nil { - httpClient := &http.Client{ - Transport: &http.Transport{ - DialContext: helper.Dialer, - }, + opt = []client.Opt{ + client.WithHost(agent.DockerHost), + client.WithHTTPClient(cfg.NewHTTPClient()), + client.WithAPIVersionNegotiation(), + } + addr = "tcp://" + cfg.Addr + dial = cfg.DialContext + } else { + switch host { + case "": + return nil, errors.New("empty docker host") + case common.DockerHostFromEnv: + opt = clientOptEnvHost + default: + helper, err := connhelper.GetConnectionHelper(host) + if err != nil { + logging.Panic().Err(err).Msg("failed to get connection helper") } - opt = []client.Opt{ - client.WithHTTPClient(httpClient), - client.WithHost(helper.Host), - client.WithAPIVersionNegotiation(), - client.WithDialContext(helper.Dialer), - } - } else { - opt = []client.Opt{ - client.WithHost(host), - client.WithAPIVersionNegotiation(), + if helper != nil { + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: helper.Dialer, + }, + } + opt = []client.Opt{ + client.WithHTTPClient(httpClient), + client.WithHost(helper.Host), + client.WithAPIVersionNegotiation(), + client.WithDialContext(helper.Dialer), + } + } else { + opt = []client.Opt{ + client.WithHost(host), + client.WithAPIVersionNegotiation(), + } } } } @@ -153,9 +187,16 @@ func ConnectClient(host string) (*SharedClient, error) { Client: client, key: host, refCount: 1, + addr: addr, + dial: dial, } - defer logging.Debug().Str("host", host).Msg("docker client connected") + // non-agent client + if c.dial == nil { + c.dial = client.Dialer() + } + + defer logging.Debug().Str("host", host).Msg("docker client initialized") clientMap[c.key] = c return c, nil diff --git a/internal/docker/list_containers.go b/internal/docker/list_containers.go index ba9b96e8..517b59d8 100644 --- a/internal/docker/list_containers.go +++ b/internal/docker/list_containers.go @@ -5,7 +5,6 @@ import ( "errors" "time" - "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" ) @@ -22,8 +21,8 @@ var listOptions = container.ListOptions{ All: true, } -func ListContainers(clientHost string) ([]types.Container, error) { - dockerClient, err := ConnectClient(clientHost) +func ListContainers(clientHost string) ([]container.Summary, error) { + dockerClient, err := NewClient(clientHost) if err != nil { return nil, err } diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 5fd33a10..e08cfc8e 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -29,7 +29,7 @@ const ( var ErrAliasRefIndexOutOfRange = gperr.New("index out of range") -func DockerProviderImpl(name, dockerHost string) (ProviderImpl, error) { +func DockerProviderImpl(name, dockerHost string) ProviderImpl { if dockerHost == common.DockerHostFromEnv { dockerHost = common.GetEnvString("DOCKER_HOST", client.DefaultDockerHost) } @@ -37,7 +37,7 @@ func DockerProviderImpl(name, dockerHost string) (ProviderImpl, error) { name, dockerHost, logging.With().Str("type", "docker").Str("name", name).Logger(), - }, nil + } } func (p *DockerProvider) String() string { @@ -61,6 +61,7 @@ func (p *DockerProvider) NewWatcher() watcher.Watcher { } func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) { + containers, err := docker.ListContainers(p.dockerHost) if err != nil { return nil, gperr.Wrap(err) } diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index 67fafa9a..cac066eb 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -39,8 +39,7 @@ func makeRoutes(cont *types.Container, dockerHostIP ...string) route.Routes { } func TestExplicitOnly(t *testing.T) { - p, err := NewDockerProvider("a!", "") - ExpectNoError(t, err) + p := NewDockerProvider("a!", "") ExpectTrue(t, p.IsExplicitOnly()) } @@ -258,16 +257,16 @@ func TestPublicIPLocalhost(t *testing.T) { c := &types.Container{Names: dummyNames, State: "running"} r, ok := makeRoutes(c)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PublicIP, "127.0.0.1") - ExpectEqual(t, r.Host, r.Container.PublicIP) + ExpectEqual(t, r.Container.PublicHostname, "127.0.0.1") + ExpectEqual(t, r.Host, r.Container.PublicHostname) } func TestPublicIPRemote(t *testing.T) { c := &types.Container{Names: dummyNames, State: "running"} raw, ok := makeRoutes(c, testIP)["a"] ExpectTrue(t, ok) - ExpectEqual(t, raw.Container.PublicIP, testIP) - ExpectEqual(t, raw.Host, raw.Container.PublicIP) + ExpectEqual(t, raw.Container.PublicHostname, testIP) + ExpectEqual(t, raw.Host, raw.Container.PublicHostname) } func TestPrivateIPLocalhost(t *testing.T) { @@ -283,8 +282,8 @@ func TestPrivateIPLocalhost(t *testing.T) { } r, ok := makeRoutes(c)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateIP, testDockerIP) - ExpectEqual(t, r.Host, r.Container.PrivateIP) + ExpectEqual(t, r.Container.PrivateHostname, testDockerIP) + ExpectEqual(t, r.Host, r.Container.PrivateHostname) } func TestPrivateIPRemote(t *testing.T) { @@ -301,9 +300,9 @@ func TestPrivateIPRemote(t *testing.T) { } r, ok := makeRoutes(c, testIP)["a"] ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateIP, "") - ExpectEqual(t, r.Container.PublicIP, testIP) - ExpectEqual(t, r.Host, r.Container.PublicIP) + ExpectEqual(t, r.Container.PrivateHostname, "") + ExpectEqual(t, r.Container.PublicHostname, testIP) + ExpectEqual(t, r.Host, r.Container.PublicHostname) } func TestStreamDefaultValues(t *testing.T) { diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 522df4ad..2f2b9395 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -6,6 +6,7 @@ import ( "github.com/yusing/go-proxy/internal/route/provider/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher" + eventsPkg "github.com/yusing/go-proxy/internal/watcher/events" ) type EventHandler struct { @@ -29,32 +30,21 @@ func (p *Provider) newEventHandler() *EventHandler { func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) { oldRoutes := handler.provider.routes - newRoutes, err := handler.provider.loadRoutes() - if err != nil { - handler.errs.Add(err) - if len(newRoutes) == 0 { - return + + isForceReload := false + for _, event := range events { + if event.Action == eventsPkg.ActionForceReload { + isForceReload = true + break } } - if common.IsDebug { - eventsLog := E.NewBuilder("events") - for _, event := range events { - eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) + newRoutes, err := handler.provider.loadRoutes() + if err != nil { + handler.errs.Add(err) + if len(newRoutes) == 0 && !isForceReload { + return } - E.LogDebug(eventsLog.About(), eventsLog.Error(), handler.provider.Logger()) - - oldRoutesLog := E.NewBuilder("old routes") - for k := range oldRoutes { - oldRoutesLog.Adds(k) - } - E.LogDebug(oldRoutesLog.About(), oldRoutesLog.Error(), handler.provider.Logger()) - - newRoutesLog := E.NewBuilder("new routes") - for k := range newRoutes { - newRoutesLog.Adds(k) - } - E.LogDebug(newRoutesLog.About(), newRoutesLog.Error(), handler.provider.Logger()) } for k, oldr := range oldRoutes { @@ -84,7 +74,7 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { switch handler.provider.GetType() { - case types.ProviderTypeDocker: + case types.ProviderTypeDocker, types.ProviderTypeAgent: return route.Container.ContainerID == event.ActorID || route.Container.ContainerName == event.ActorName case types.ProviderTypeFile: diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index e41ad129..eed8ff25 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -58,16 +58,13 @@ func NewFileProvider(filename string) (p *Provider, err error) { return } -func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) { - if name == "" { - return nil, ErrEmptyProviderName - } +func NewDockerProvider(name string, dockerHost string) *Provider { + p := newProvider(types.ProviderTypeDocker) + p.ProviderImpl = DockerProviderImpl(name, dockerHost) + p.watcher = p.NewWatcher() + return p +} - p = newProvider(types.ProviderTypeDocker) - p.ProviderImpl, err = DockerProviderImpl(name, dockerHost) - if err != nil { - return nil, err - } p.watcher = p.NewWatcher() return } @@ -151,6 +148,7 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) { } if r.ShouldExclude() { delete(routes, alias) + continue } } return routes, errs.Error() diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index ae6b6532..4cae2186 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -2,12 +2,15 @@ package watcher import ( "context" + "errors" "time" docker_events "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/watcher/events" ) @@ -41,72 +44,110 @@ var ( )} dockerWatcherRetryInterval = 3 * time.Second + + reloadTrigger = Event{ + Type: events.EventTypeDocker, + Action: events.ActionForceReload, + ActorAttributes: map[string]string{}, + ActorName: "", + ActorID: "", + } ) func DockerFilterContainerNameID(nameOrID string) filters.KeyValuePair { return filters.Arg("container", nameOrID) } -func NewDockerWatcher(host string) DockerWatcher { - return DockerWatcher{host: host} +func NewDockerWatcher(host string) *DockerWatcher { + return &DockerWatcher{host: host} } func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan gperr.Error) { return w.EventsWithOptions(ctx, optionsDefault) } +func (w DockerWatcher) parseError(err error) gperr.Error { + if errors.Is(err, context.DeadlineExceeded) { + return gperr.New("docker client connection timeout") + } + if client.IsErrConnectionFailed(err) { + return gperr.New("docker client connection failure") + } + return gperr.Wrap(err) +} + +func (w *DockerWatcher) checkConnection(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, dockerWatcherRetryInterval) + defer cancel() + err := w.client.CheckConnection(ctx) + if err != nil { + logging.Debug().Err(err).Msg("docker watcher: connection failed") + return false + } + return true +} + +func (w *DockerWatcher) handleEvent(event docker_events.Message, ch chan<- Event) { + action, ok := events.DockerEventMap[event.Action] + if !ok { + return + } + ch <- Event{ + Type: events.EventTypeDocker, + ActorID: event.Actor.ID, + ActorAttributes: event.Actor.Attributes, // labels + ActorName: event.Actor.Attributes["name"], + Action: action, + } +} + func (w *DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan gperr.Error) { eventCh := make(chan Event) errCh := make(chan gperr.Error) go func() { + var err error + w.client, err = docker.NewClient(w.host) + if err != nil { + errCh <- gperr.Wrap(err, "docker watcher: failed to initialize client") + return + } + defer func() { - defer close(eventCh) - defer close(errCh) + close(eventCh) + close(errCh) w.client.Close() }() - client, err := docker.ConnectClient(w.host) - if err != nil { - errCh <- E.From(err) - return - } - w.client = client - cEventCh, cErrCh := w.client.Events(ctx, options) - + defer logging.Debug().Str("host", w.client.Address()).Msg("docker watcher closed") for { select { case <-ctx.Done(): - if err := E.From(ctx.Err()); err != nil && !err.Is(context.Canceled) { - errCh <- err - } return case msg := <-cEventCh: - action, ok := events.DockerEventMap[msg.Action] - if !ok { - continue - } - event := Event{ - Type: events.EventTypeDocker, - ActorID: msg.Actor.ID, - ActorAttributes: msg.Actor.Attributes, // labels - ActorName: msg.Actor.Attributes["name"], - Action: action, - } - eventCh <- event + w.handleEvent(msg, eventCh) case err := <-cErrCh: if err == nil { continue } - errCh <- E.From(err) - select { - case <-ctx.Done(): - return - default: - time.Sleep(dockerWatcherRetryInterval) - cEventCh, cErrCh = w.client.Events(ctx, options) + errCh <- w.parseError(err) + // release the error because reopening event channel may block + err = nil + // trigger reload (clear routes) + eventCh <- reloadTrigger + for !w.checkConnection(ctx) { + select { + case <-ctx.Done(): + return + case <-time.After(dockerWatcherRetryInterval): + continue + } } + // connection successful, trigger reload (reload routes) + eventCh <- reloadTrigger + // reopen event channel + cEventCh, cErrCh = w.client.Events(ctx, options) } } }()