diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 95b4f457..2efdac5f 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -123,6 +123,24 @@ func (r *ReverseProxyRoute) ReverseProxy() *reverseproxy.ReverseProxy { return r.rp } +func (r *ReverseProxyRoute) isSyntheticLoadBalancerRoute() bool { + return r.loadBalancer != nil && r.rp == nil +} + +func (r *ReverseProxyRoute) Key() string { + if r.isSyntheticLoadBalancerRoute() { + return r.Alias + } + return r.Route.Key() +} + +func (r *ReverseProxyRoute) ShouldExclude() bool { + if r.isSyntheticLoadBalancerRoute() { + return false + } + return r.Route.ShouldExclude() +} + // Start implements task.TaskStarter. func (r *ReverseProxyRoute) Start(parent task.Parent) error { r.task = parent.Subtask("http."+r.Name(), false) @@ -206,7 +224,7 @@ func (r *ReverseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint. linked = l.(*ReverseProxyRoute) // it must be a reverse proxy route lb = linked.loadBalancer lb.UpdateConfigIfNeeded(cfg) - if linked.Homepage.Name == "" { + if linked.Homepage == nil || linked.Homepage.Name == "" { linked.Homepage = r.Homepage } } else { diff --git a/internal/route/reverse_proxy_test.go b/internal/route/reverse_proxy_test.go index bd579c83..10841242 100644 --- a/internal/route/reverse_proxy_test.go +++ b/internal/route/reverse_proxy_test.go @@ -1,16 +1,165 @@ package route import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" + "github.com/yusing/godoxy/internal/homepage" route "github.com/yusing/godoxy/internal/route/types" "github.com/yusing/godoxy/internal/types" + "github.com/yusing/goutils/task" ) +type testPool[T interface{ Key() string }] struct { + mu sync.RWMutex + items map[string]T +} + +func newTestPool[T interface{ Key() string }]() *testPool[T] { + return &testPool[T]{items: make(map[string]T)} +} + +func (p *testPool[T]) Get(alias string) (T, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + v, ok := p.items[alias] + return v, ok +} + +func (p *testPool[T]) Iter(yield func(alias string, r T) bool) { + p.mu.RLock() + defer p.mu.RUnlock() + for alias, r := range p.items { + if !yield(alias, r) { + return + } + } +} + +func (p *testPool[T]) Size() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.items) +} + +func (p *testPool[T]) Add(r T) { + p.mu.Lock() + defer p.mu.Unlock() + p.items[r.Key()] = r +} + +func (p *testPool[T]) Del(r T) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.items, r.Key()) +} + +type testEntrypoint struct { + httpRoutes *testPool[types.HTTPRoute] + streamRoutes *testPool[types.StreamRoute] + excludedRoutes *testPool[types.Route] +} + +func newTestEntrypoint() *testEntrypoint { + return &testEntrypoint{ + httpRoutes: newTestPool[types.HTTPRoute](), + streamRoutes: newTestPool[types.StreamRoute](), + excludedRoutes: newTestPool[types.Route](), + } +} + +func (ep *testEntrypoint) SupportProxyProtocol() bool { return false } +func (ep *testEntrypoint) DisablePoolsLog(bool) {} + +func (ep *testEntrypoint) GetRoute(alias string) (types.Route, bool) { + if r, ok := ep.httpRoutes.Get(alias); ok { + return r, true + } + if r, ok := ep.streamRoutes.Get(alias); ok { + return r, true + } + if r, ok := ep.excludedRoutes.Get(alias); ok { + return r, true + } + return nil, false +} + +func (ep *testEntrypoint) StartAddRoute(r types.Route) error { + if r.ShouldExclude() { + ep.excludedRoutes.Add(r) + return nil + } + switch rt := r.(type) { + case types.HTTPRoute: + ep.httpRoutes.Add(rt) + return nil + case types.StreamRoute: + ep.streamRoutes.Add(rt) + return nil + default: + return fmt.Errorf("unknown route type: %T", r) + } +} + +func (ep *testEntrypoint) IterRoutes(yield func(r types.Route) bool) { + ep.httpRoutes.Iter(func(_ string, r types.HTTPRoute) bool { + return yield(r) + }) + ep.streamRoutes.Iter(func(_ string, r types.StreamRoute) bool { + return yield(r) + }) + ep.excludedRoutes.Iter(func(_ string, r types.Route) bool { + return yield(r) + }) +} + +func (ep *testEntrypoint) NumRoutes() int { + return ep.httpRoutes.Size() + ep.streamRoutes.Size() + ep.excludedRoutes.Size() +} + +func (ep *testEntrypoint) RoutesByProvider() map[string][]types.Route { + return map[string][]types.Route{} +} + +func (ep *testEntrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] { + return ep.httpRoutes +} + +func (ep *testEntrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] { + return ep.streamRoutes +} + +func (ep *testEntrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] { + return ep.excludedRoutes +} + +func (ep *testEntrypoint) GetHealthInfo() map[string]types.HealthInfo { + return nil +} + +func (ep *testEntrypoint) GetHealthInfoWithoutDetail() map[string]types.HealthInfoWithoutDetail { + return nil +} + +func (ep *testEntrypoint) GetHealthInfoSimple() map[string]types.HealthStatus { + return nil +} + func TestReverseProxyRoute(t *testing.T) { t.Run("LinkToLoadBalancer", func(t *testing.T) { + testTask := task.GetTestTask(t) + entrypoint.SetCtx(testTask, newTestEntrypoint()) + cfg := Route{ Alias: "test", Scheme: route.SchemeHTTP, @@ -36,4 +185,75 @@ func TestReverseProxyRoute(t *testing.T) { require.NoError(t, err) assert.NotNil(t, r2) }) + t.Run("LoadBalancerRoute", func(t *testing.T) { + testTask := task.GetTestTask(t) + entrypoint.SetCtx(testTask, newTestEntrypoint()) + + newServer := func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + } + + srv1 := newServer() + t.Cleanup(srv1.Close) + srv2 := newServer() + t.Cleanup(srv2.Close) + srv3 := newServer() + t.Cleanup(srv3.Close) + + makeRoute := func(alias string, target *httptest.Server) *Route { + t.Helper() + + targetURL, err := url.Parse(target.URL) + require.NoError(t, err) + + host, portStr, err := net.SplitHostPort(targetURL.Host) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + return &Route{ + Alias: alias, + Scheme: route.SchemeHTTP, + Host: host, + Port: Port{Proxy: port}, + Homepage: &homepage.ItemConfig{ + Show: true, + }, + LoadBalance: &types.LoadBalancerConfig{ + Link: "lb-test", + }, + HealthCheck: types.HealthCheckConfig{ + Path: "/", + Interval: 2 * time.Second, + Timeout: time.Second, + UseGet: true, + }, + } + } + + _, err := NewStartedTestRoute(t, makeRoute("lb-1", srv1)) + require.NoError(t, err) + _, err = NewStartedTestRoute(t, makeRoute("lb-2", srv2)) + require.NoError(t, err) + _, err = NewStartedTestRoute(t, makeRoute("lb-3", srv3)) + require.NoError(t, err) + + ep := entrypoint.FromCtx(testTask.Context()) + require.NotNil(t, ep) + + lbRoute, ok := ep.HTTPRoutes().Get("lb-test") + require.True(t, ok) + + lb, ok := lbRoute.(*ReverseProxyRoute) + require.True(t, ok) + require.False(t, lb.ShouldExclude()) + require.NotNil(t, lb.loadBalancer) + require.NotNil(t, lb.HealthMonitor()) + assert.Equal(t, route.SchemeNone, lb.Scheme) + assert.Empty(t, lb.Host) + assert.Zero(t, lb.Port.Proxy) + assert.Equal(t, "3/3 servers are healthy", lb.HealthMonitor().Detail()) + }) }