mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
refactor(loadbalancer): implement sticky sessions and improve algorithm separation
- Refactor load balancer interface to separate server selection (ChooseServer) from request handling - Add cookie-based sticky session support with configurable max-age and secure cookie handling - Integrate idlewatcher requests with automatic sticky session assignment - Improve algorithm implementations: * Replace fnv with xxhash3 for better performance in IP hash and server keys * Add proper bounds checking and error handling in all algorithms * Separate concerns between server selection and request processing - Add Sticky and StickyMaxAge fields to LoadBalancerConfig - Create dedicated sticky.go for session management utilities
This commit is contained in:
@@ -1,11 +1,11 @@
|
|||||||
package loadbalancer
|
package loadbalancer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hash/fnv"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/xxhash3"
|
||||||
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
||||||
"github.com/yusing/godoxy/internal/types"
|
"github.com/yusing/godoxy/internal/types"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
@@ -19,6 +19,9 @@ type ipHash struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ impl = (*ipHash)(nil)
|
||||||
|
var _ customServeHTTP = (*ipHash)(nil)
|
||||||
|
|
||||||
func (lb *LoadBalancer) newIPHash() impl {
|
func (lb *LoadBalancer) newIPHash() impl {
|
||||||
impl := &ipHash{LoadBalancer: lb}
|
impl := &ipHash{LoadBalancer: lb}
|
||||||
if len(lb.Options) == 0 {
|
if len(lb.Options) == 0 {
|
||||||
@@ -62,31 +65,27 @@ func (impl *ipHash) OnRemoveServer(srv types.LoadBalancerServer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (impl *ipHash) ServeHTTP(_ types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
func (impl *ipHash) ServeHTTP(_ types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
||||||
if impl.realIP != nil {
|
srv := impl.ChooseServer(impl.pool, r)
|
||||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
|
||||||
} else {
|
|
||||||
impl.serveHTTP(rw, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
|
||||||
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
|
||||||
|
|
||||||
srv := impl.pool[idx]
|
|
||||||
if srv == nil || srv.Status().Bad() {
|
if srv == nil || srv.Status().Bad() {
|
||||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if impl.realIP != nil {
|
||||||
|
impl.realIP.ModifyRequest(srv.ServeHTTP, rw, r)
|
||||||
|
} else {
|
||||||
|
srv.ServeHTTP(rw, r)
|
||||||
}
|
}
|
||||||
srv.ServeHTTP(rw, r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashIP(ip string) uint32 {
|
func (impl *ipHash) ChooseServer(_ types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer {
|
||||||
h := fnv.New32a()
|
if len(impl.pool) == 0 {
|
||||||
h.Write([]byte(ip))
|
return nil
|
||||||
return h.Sum32()
|
}
|
||||||
|
|
||||||
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = r.RemoteAddr
|
||||||
|
}
|
||||||
|
return impl.pool[xxhash3.HashString(ip)%uint64(len(impl.pool))]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ type leastConn struct {
|
|||||||
nConn *xsync.Map[types.LoadBalancerServer, *atomic.Int64]
|
nConn *xsync.Map[types.LoadBalancerServer, *atomic.Int64]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ impl = (*leastConn)(nil)
|
||||||
|
var _ customServeHTTP = (*leastConn)(nil)
|
||||||
|
|
||||||
func (lb *LoadBalancer) newLeastConn() impl {
|
func (lb *LoadBalancer) newLeastConn() impl {
|
||||||
return &leastConn{
|
return &leastConn{
|
||||||
LoadBalancer: lb,
|
LoadBalancer: lb,
|
||||||
@@ -29,18 +32,39 @@ func (impl *leastConn) OnRemoveServer(srv types.LoadBalancerServer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
||||||
srv := srvs[0]
|
srv := impl.ChooseServer(srvs, r)
|
||||||
|
if srv == nil {
|
||||||
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
minConn, ok := impl.nConn.Load(srv)
|
minConn, ok := impl.nConn.Load(srv)
|
||||||
if !ok {
|
if !ok {
|
||||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
|
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
|
||||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
minConn.Add(1)
|
||||||
|
srv.ServeHTTP(rw, r)
|
||||||
|
minConn.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (impl *leastConn) ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer {
|
||||||
|
if len(srvs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := srvs[0]
|
||||||
|
minConn, ok := impl.nConn.Load(srv)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 1; i < len(srvs); i++ {
|
for i := 1; i < len(srvs); i++ {
|
||||||
nConn, ok := impl.nConn.Load(srvs[i])
|
nConn, ok := impl.nConn.Load(srvs[i])
|
||||||
if !ok {
|
if !ok {
|
||||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
|
continue
|
||||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
|
||||||
}
|
}
|
||||||
if nConn.Load() < minConn.Load() {
|
if nConn.Load() < minConn.Load() {
|
||||||
minConn = nConn
|
minConn = nConn
|
||||||
@@ -48,7 +72,5 @@ func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.Respons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
minConn.Add(1)
|
return srv
|
||||||
srv.ServeHTTP(rw, r)
|
|
||||||
minConn.Add(-1)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,9 +20,12 @@ import (
|
|||||||
// TODO: support weighted mode.
|
// TODO: support weighted mode.
|
||||||
type (
|
type (
|
||||||
impl interface {
|
impl interface {
|
||||||
ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request)
|
|
||||||
OnAddServer(srv types.LoadBalancerServer)
|
OnAddServer(srv types.LoadBalancerServer)
|
||||||
OnRemoveServer(srv types.LoadBalancerServer)
|
OnRemoveServer(srv types.LoadBalancerServer)
|
||||||
|
ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer
|
||||||
|
}
|
||||||
|
customServeHTTP interface {
|
||||||
|
ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
LoadBalancer struct {
|
LoadBalancer struct {
|
||||||
@@ -235,7 +238,33 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||||||
gperr.LogWarn("failed to wake some servers", err, &lb.l)
|
gperr.LogWarn("failed to wake some servers", err, &lb.l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lb.impl.ServeHTTP(srvs, rw, r)
|
|
||||||
|
// Check for idlewatcher requests or sticky sessions
|
||||||
|
if lb.Sticky || isIdlewatcherRequest(r) {
|
||||||
|
if selectedServer := getStickyServer(r, srvs); selectedServer != nil {
|
||||||
|
selectedServer.ServeHTTP(rw, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// No sticky session, choose a server and set cookie
|
||||||
|
selectedServer := lb.impl.ChooseServer(srvs, r)
|
||||||
|
if selectedServer != nil {
|
||||||
|
setStickyCookie(rw, r, selectedServer, lb.StickyMaxAge)
|
||||||
|
selectedServer.ServeHTTP(rw, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if customServeHTTP, ok := lb.impl.(customServeHTTP); ok {
|
||||||
|
customServeHTTP.ServeHTTP(srvs, rw, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedServer := lb.ChooseServer(srvs, r)
|
||||||
|
if selectedServer == nil {
|
||||||
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
selectedServer.ServeHTTP(rw, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements health.HealthMonitor.
|
// MarshalJSON implements health.HealthMonitor.
|
||||||
@@ -324,3 +353,22 @@ func (lb *LoadBalancer) availServers() []types.LoadBalancerServer {
|
|||||||
}
|
}
|
||||||
return avail
|
return avail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIdlewatcherRequest checks if this is an idlewatcher-related request
|
||||||
|
func isIdlewatcherRequest(r *http.Request) bool {
|
||||||
|
// Check for explicit idlewatcher paths
|
||||||
|
if r.URL.Path == idlewatcher.WakeEventsPath ||
|
||||||
|
r.URL.Path == idlewatcher.FavIconPath ||
|
||||||
|
r.URL.Path == idlewatcher.LoadingPageCSSPath ||
|
||||||
|
r.URL.Path == idlewatcher.LoadingPageJSPath {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a page refresh after idlewatcher wake up
|
||||||
|
// by looking for the sticky session cookie
|
||||||
|
if _, err := r.Cookie("godoxy_lb_sticky"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,14 +11,19 @@ type roundRobin struct {
|
|||||||
index atomic.Uint32
|
index atomic.Uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ impl = (*roundRobin)(nil)
|
||||||
|
|
||||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||||
func (lb *roundRobin) OnAddServer(srv types.LoadBalancerServer) {}
|
func (lb *roundRobin) OnAddServer(srv types.LoadBalancerServer) {}
|
||||||
func (lb *roundRobin) OnRemoveServer(srv types.LoadBalancerServer) {}
|
func (lb *roundRobin) OnRemoveServer(srv types.LoadBalancerServer) {}
|
||||||
|
|
||||||
func (lb *roundRobin) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
func (lb *roundRobin) ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer {
|
||||||
|
if len(srvs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
index := lb.index.Add(1) % uint32(len(srvs))
|
index := lb.index.Add(1) % uint32(len(srvs))
|
||||||
srvs[index].ServeHTTP(rw, r)
|
|
||||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||||
lb.index.Store(0)
|
lb.index.Store(0)
|
||||||
}
|
}
|
||||||
|
return srvs[index]
|
||||||
}
|
}
|
||||||
|
|||||||
50
internal/net/gphttp/loadbalancer/sticky.go
Normal file
50
internal/net/gphttp/loadbalancer/sticky.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/xxhash3"
|
||||||
|
"github.com/yusing/godoxy/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func hashServerKey(key string) string {
|
||||||
|
h := xxhash3.HashString(key)
|
||||||
|
as8bytes := *(*[8]byte)(unsafe.Pointer(&h))
|
||||||
|
return hex.EncodeToString(as8bytes[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStickyServer extracts the sticky session cookie and returns the corresponding server
|
||||||
|
func getStickyServer(r *http.Request, srvs []types.LoadBalancerServer) types.LoadBalancerServer {
|
||||||
|
cookie, err := r.Cookie("godoxy_lb_sticky")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serverKeyHash := cookie.Value
|
||||||
|
for _, srv := range srvs {
|
||||||
|
if hashServerKey(srv.Key()) == serverKeyHash {
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setStickyCookie sets a cookie to maintain sticky session with a specific server
|
||||||
|
func setStickyCookie(rw http.ResponseWriter, r *http.Request, srv types.LoadBalancerServer, maxAge time.Duration) {
|
||||||
|
http.SetCookie(rw, &http.Cookie{
|
||||||
|
Name: "godoxy_lb_sticky",
|
||||||
|
Value: hashServerKey(srv.Key()),
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(maxAge.Seconds()),
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: isSecure(r),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSecure(r *http.Request) bool {
|
||||||
|
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package types
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||||
strutils "github.com/yusing/goutils/strings"
|
strutils "github.com/yusing/goutils/strings"
|
||||||
@@ -9,10 +10,12 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
LoadBalancerConfig struct {
|
LoadBalancerConfig struct {
|
||||||
Link string `json:"link"`
|
Link string `json:"link"`
|
||||||
Mode LoadBalancerMode `json:"mode"`
|
Mode LoadBalancerMode `json:"mode"`
|
||||||
Weight int `json:"weight"`
|
Weight int `json:"weight"`
|
||||||
Options map[string]any `json:"options,omitempty"`
|
Sticky bool `json:"sticky"`
|
||||||
|
StickyMaxAge time.Duration `json:"sticky_max_age"`
|
||||||
|
Options map[string]any `json:"options,omitempty"`
|
||||||
} // @name LoadBalancerConfig
|
} // @name LoadBalancerConfig
|
||||||
LoadBalancerMode string // @name LoadBalancerMode
|
LoadBalancerMode string // @name LoadBalancerMode
|
||||||
LoadBalancerServer interface {
|
LoadBalancerServer interface {
|
||||||
@@ -35,6 +38,8 @@ const (
|
|||||||
LoadbalanceModeIPHash LoadBalancerMode = "iphash"
|
LoadbalanceModeIPHash LoadBalancerMode = "iphash"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const StickyMaxAgeDefault = 1 * time.Hour
|
||||||
|
|
||||||
func (mode *LoadBalancerMode) ValidateUpdate() bool {
|
func (mode *LoadBalancerMode) ValidateUpdate() bool {
|
||||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||||
case "":
|
case "":
|
||||||
|
|||||||
Reference in New Issue
Block a user