diff --git a/internal/net/gphttp/loadbalancer/ip_hash.go b/internal/net/gphttp/loadbalancer/ip_hash.go index 1abce638..fdea282d 100644 --- a/internal/net/gphttp/loadbalancer/ip_hash.go +++ b/internal/net/gphttp/loadbalancer/ip_hash.go @@ -1,11 +1,11 @@ package loadbalancer import ( - "hash/fnv" "net" "net/http" "sync" + "github.com/bytedance/gopkg/util/xxhash3" "github.com/yusing/godoxy/internal/net/gphttp/middleware" "github.com/yusing/godoxy/internal/types" gperr "github.com/yusing/goutils/errs" @@ -19,6 +19,9 @@ type ipHash struct { mu sync.Mutex } +var _ impl = (*ipHash)(nil) +var _ customServeHTTP = (*ipHash)(nil) + func (lb *LoadBalancer) newIPHash() impl { impl := &ipHash{LoadBalancer: lb} if len(lb.Options) == 0 { @@ -62,31 +65,27 @@ func (impl *ipHash) OnRemoveServer(srv types.LoadBalancerServer) { } func (impl *ipHash) ServeHTTP(_ types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) { - if impl.realIP != nil { - impl.realIP.ModifyRequest(impl.serveHTTP, rw, r) - } else { - impl.serveHTTP(rw, r) - } -} - -func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - http.Error(rw, "Internal error", http.StatusInternalServerError) - impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr) - return - } - idx := hashIP(ip) % uint32(len(impl.pool)) - - srv := impl.pool[idx] + srv := impl.ChooseServer(impl.pool, r) if srv == nil || srv.Status().Bad() { http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + return + } + + if impl.realIP != nil { + impl.realIP.ModifyRequest(srv.ServeHTTP, rw, r) + } else { + srv.ServeHTTP(rw, r) } - srv.ServeHTTP(rw, r) } -func hashIP(ip string) uint32 { - h := fnv.New32a() - h.Write([]byte(ip)) - return h.Sum32() +func (impl *ipHash) ChooseServer(_ types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer { + if len(impl.pool) == 0 { + return nil + } + + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + return impl.pool[xxhash3.HashString(ip)%uint64(len(impl.pool))] } diff --git a/internal/net/gphttp/loadbalancer/least_conn.go b/internal/net/gphttp/loadbalancer/least_conn.go index a810c0d2..0783e177 100644 --- a/internal/net/gphttp/loadbalancer/least_conn.go +++ b/internal/net/gphttp/loadbalancer/least_conn.go @@ -13,6 +13,9 @@ type leastConn struct { nConn *xsync.Map[types.LoadBalancerServer, *atomic.Int64] } +var _ impl = (*leastConn)(nil) +var _ customServeHTTP = (*leastConn)(nil) + func (lb *LoadBalancer) newLeastConn() impl { return &leastConn{ LoadBalancer: lb, @@ -29,18 +32,39 @@ func (impl *leastConn) OnRemoveServer(srv types.LoadBalancerServer) { } func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) { - srv := srvs[0] + srv := impl.ChooseServer(srvs, r) + if srv == nil { + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + return + } + minConn, ok := impl.nConn.Load(srv) if !ok { impl.l.Error().Msgf("[BUG] server %s not found", srv.Name()) http.Error(rw, "Internal error", http.StatusInternalServerError) + return + } + + minConn.Add(1) + srv.ServeHTTP(rw, r) + minConn.Add(-1) +} + +func (impl *leastConn) ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer { + if len(srvs) == 0 { + return nil + } + + srv := srvs[0] + minConn, ok := impl.nConn.Load(srv) + if !ok { + return nil } for i := 1; i < len(srvs); i++ { nConn, ok := impl.nConn.Load(srvs[i]) if !ok { - impl.l.Error().Msgf("[BUG] server %s not found", srv.Name()) - http.Error(rw, "Internal error", http.StatusInternalServerError) + continue } if nConn.Load() < minConn.Load() { minConn = nConn @@ -48,7 +72,5 @@ func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.Respons } } - minConn.Add(1) - srv.ServeHTTP(rw, r) - minConn.Add(-1) + return srv } diff --git a/internal/net/gphttp/loadbalancer/loadbalancer.go b/internal/net/gphttp/loadbalancer/loadbalancer.go index fa79782f..b8251cc4 100644 --- a/internal/net/gphttp/loadbalancer/loadbalancer.go +++ b/internal/net/gphttp/loadbalancer/loadbalancer.go @@ -20,9 +20,12 @@ import ( // TODO: support weighted mode. type ( impl interface { - ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) OnAddServer(srv types.LoadBalancerServer) OnRemoveServer(srv types.LoadBalancerServer) + ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer + } + customServeHTTP interface { + ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) } LoadBalancer struct { @@ -235,7 +238,33 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { gperr.LogWarn("failed to wake some servers", err, &lb.l) } } - lb.impl.ServeHTTP(srvs, rw, r) + + // Check for idlewatcher requests or sticky sessions + if lb.Sticky || isIdlewatcherRequest(r) { + if selectedServer := getStickyServer(r, srvs); selectedServer != nil { + selectedServer.ServeHTTP(rw, r) + return + } + // No sticky session, choose a server and set cookie + selectedServer := lb.impl.ChooseServer(srvs, r) + if selectedServer != nil { + setStickyCookie(rw, r, selectedServer, lb.StickyMaxAge) + selectedServer.ServeHTTP(rw, r) + return + } + } + + if customServeHTTP, ok := lb.impl.(customServeHTTP); ok { + customServeHTTP.ServeHTTP(srvs, rw, r) + return + } + + selectedServer := lb.ChooseServer(srvs, r) + if selectedServer == nil { + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + return + } + selectedServer.ServeHTTP(rw, r) } // MarshalJSON implements health.HealthMonitor. @@ -324,3 +353,22 @@ func (lb *LoadBalancer) availServers() []types.LoadBalancerServer { } return avail } + +// isIdlewatcherRequest checks if this is an idlewatcher-related request +func isIdlewatcherRequest(r *http.Request) bool { + // Check for explicit idlewatcher paths + if r.URL.Path == idlewatcher.WakeEventsPath || + r.URL.Path == idlewatcher.FavIconPath || + r.URL.Path == idlewatcher.LoadingPageCSSPath || + r.URL.Path == idlewatcher.LoadingPageJSPath { + return true + } + + // Check if this is a page refresh after idlewatcher wake up + // by looking for the sticky session cookie + if _, err := r.Cookie("godoxy_lb_sticky"); err == nil { + return true + } + + return false +} diff --git a/internal/net/gphttp/loadbalancer/round_robin.go b/internal/net/gphttp/loadbalancer/round_robin.go index e0602630..31c61c9e 100644 --- a/internal/net/gphttp/loadbalancer/round_robin.go +++ b/internal/net/gphttp/loadbalancer/round_robin.go @@ -11,14 +11,19 @@ type roundRobin struct { index atomic.Uint32 } +var _ impl = (*roundRobin)(nil) + func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} } func (lb *roundRobin) OnAddServer(srv types.LoadBalancerServer) {} func (lb *roundRobin) OnRemoveServer(srv types.LoadBalancerServer) {} -func (lb *roundRobin) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) { +func (lb *roundRobin) ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer { + if len(srvs) == 0 { + return nil + } index := lb.index.Add(1) % uint32(len(srvs)) - srvs[index].ServeHTTP(rw, r) if lb.index.Load() >= 2*uint32(len(srvs)) { lb.index.Store(0) } + return srvs[index] } diff --git a/internal/net/gphttp/loadbalancer/sticky.go b/internal/net/gphttp/loadbalancer/sticky.go new file mode 100644 index 00000000..ffe7f486 --- /dev/null +++ b/internal/net/gphttp/loadbalancer/sticky.go @@ -0,0 +1,50 @@ +package loadbalancer + +import ( + "encoding/hex" + "net/http" + "time" + "unsafe" + + "github.com/bytedance/gopkg/util/xxhash3" + "github.com/yusing/godoxy/internal/types" +) + +func hashServerKey(key string) string { + h := xxhash3.HashString(key) + as8bytes := *(*[8]byte)(unsafe.Pointer(&h)) + return hex.EncodeToString(as8bytes[:]) +} + +// getStickyServer extracts the sticky session cookie and returns the corresponding server +func getStickyServer(r *http.Request, srvs []types.LoadBalancerServer) types.LoadBalancerServer { + cookie, err := r.Cookie("godoxy_lb_sticky") + if err != nil { + return nil + } + + serverKeyHash := cookie.Value + for _, srv := range srvs { + if hashServerKey(srv.Key()) == serverKeyHash { + return srv + } + } + return nil +} + +// setStickyCookie sets a cookie to maintain sticky session with a specific server +func setStickyCookie(rw http.ResponseWriter, r *http.Request, srv types.LoadBalancerServer, maxAge time.Duration) { + http.SetCookie(rw, &http.Cookie{ + Name: "godoxy_lb_sticky", + Value: hashServerKey(srv.Key()), + Path: "/", + MaxAge: int(maxAge.Seconds()), + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: isSecure(r), + }) +} + +func isSecure(r *http.Request) bool { + return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" +} diff --git a/internal/types/loadbalancer.go b/internal/types/loadbalancer.go index 2dae15a5..c1449f32 100644 --- a/internal/types/loadbalancer.go +++ b/internal/types/loadbalancer.go @@ -2,6 +2,7 @@ package types import ( "net/http" + "time" nettypes "github.com/yusing/godoxy/internal/net/types" strutils "github.com/yusing/goutils/strings" @@ -9,10 +10,12 @@ import ( type ( LoadBalancerConfig struct { - Link string `json:"link"` - Mode LoadBalancerMode `json:"mode"` - Weight int `json:"weight"` - Options map[string]any `json:"options,omitempty"` + Link string `json:"link"` + Mode LoadBalancerMode `json:"mode"` + Weight int `json:"weight"` + Sticky bool `json:"sticky"` + StickyMaxAge time.Duration `json:"sticky_max_age"` + Options map[string]any `json:"options,omitempty"` } // @name LoadBalancerConfig LoadBalancerMode string // @name LoadBalancerMode LoadBalancerServer interface { @@ -35,6 +38,8 @@ const ( LoadbalanceModeIPHash LoadBalancerMode = "iphash" ) +const StickyMaxAgeDefault = 1 * time.Hour + func (mode *LoadBalancerMode) ValidateUpdate() bool { switch strutils.ToLowerNoSnake(string(*mode)) { case "":