Files
godoxy-yusing/internal/net/gphttp/loadbalancer/loadbalancer.go
yusing 424398442b refactor: replace gperr.Builder with gperr.Group for concurrent error handling
- Updated various files to utilize gperr.Group for cleaner concurrency error handling.
- Removed sync.WaitGroup usage, simplifying the code structure.
- Ensured consistent error reporting across different components.
2026-01-06 16:29:35 +08:00

368 lines
8.3 KiB
Go

package loadbalancer
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
idlewatcher "github.com/yusing/godoxy/internal/idlewatcher/types"
"github.com/yusing/godoxy/internal/types"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/pool"
"github.com/yusing/goutils/task"
)
// TODO: stats of each server.
// TODO: support weighted mode.

// impl is the balancing strategy selected from the configured mode. It is
// told about every server entering or leaving the pool and picks the server
// that should handle a request.
type impl interface {
	OnAddServer(srv types.LoadBalancerServer)
	OnRemoveServer(srv types.LoadBalancerServer)
	ChooseServer(srvs types.LoadBalancerServers, r *http.Request) types.LoadBalancerServer
}

// customServeHTTP may optionally be implemented by an impl that wants to
// serve non-sticky requests itself instead of having the load balancer call
// ChooseServer and forward to the result.
type customServeHTTP interface {
	ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request)
}

// LoadBalancer distributes requests across a pool of servers using the
// strategy chosen by its configuration, and reports aggregate health for
// that pool.
type LoadBalancer struct {
	impl
	*types.LoadBalancerConfig

	// task is the subtask created in Start.
	task *task.Task

	// pool holds the registered servers, keyed by their Key().
	pool pool.Pool[types.LoadBalancerServer]
	// poolMu is held by AddServer and RemoveServer while they mutate the
	// pool, sumWeight and the strategy's server registrations.
	poolMu sync.Mutex

	// sumWeight is the running total of pooled server weights, adjusted by
	// AddServer/RemoveServer and normalized toward maxWeight by rebalance.
	sumWeight int
	// startTime is recorded by Start and used for uptime reporting.
	startTime time.Time
	// l is a logger tagged with the load balancer's name.
	l zerolog.Logger
}
// maxWeight is the total that rebalance normalizes the pool's server weights to.
const maxWeight = int(100)
// New builds a LoadBalancer around cfg. The server pool is named
// "loadbalancer.<Link>", the logger is tagged with the link name, and the
// balancing strategy is selected through UpdateConfigIfNeeded.
func New(cfg *types.LoadBalancerConfig) *LoadBalancer {
	poolName := "loadbalancer." + cfg.Link
	lb := &LoadBalancer{
		LoadBalancerConfig: cfg,
		pool:               pool.New[types.LoadBalancerServer](poolName),
		l:                  log.With().Str("name", cfg.Link).Logger(),
	}
	lb.UpdateConfigIfNeeded(cfg)
	return lb
}
// Start records the start time and creates the load balancer's subtask
// ("loadbalancer.<Link>") under parent. When that subtask is cancelled, every
// pooled server is handed to the strategy's OnRemoveServer (if a strategy is
// set) and the subtask is finished. Start always returns nil.
//
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
	lb.startTime = time.Now()
	lb.task = parent.Subtask("loadbalancer."+lb.Link, true)

	cleanup := func() {
		if lb.impl == nil {
			lb.task.Finish(nil)
			return
		}
		for _, srv := range lb.pool.Iter {
			lb.impl.OnRemoveServer(srv)
		}
		lb.task.Finish(nil)
	}
	lb.task.OnCancel("cleanup", cleanup)
	return nil
}
// Task returns the subtask assigned by Start.
//
// Task implements task.TaskStarter.
func (lb *LoadBalancer) Task() *task.Task {
	t := lb.task
	return t
}
// Finish finishes the load balancer's subtask, passing reason through.
//
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) {
	t := lb.task
	t.Finish(reason)
}
// updateImpl replaces lb.impl with a fresh strategy for lb.Mode and registers
// every server currently in the pool with it. Least-connection and IP-hash
// modes get their own strategies; unset, round-robin and any unrecognized
// mode (expected only in tests) fall back to round-robin.
func (lb *LoadBalancer) updateImpl() {
	var next impl
	switch lb.Mode {
	case types.LoadbalanceModeLeastConn:
		next = lb.newLeastConn()
	case types.LoadbalanceModeIPHash:
		next = lb.newIPHash()
	default:
		next = lb.newRoundRobin()
	}
	lb.impl = next

	for _, srv := range lb.pool.Iter {
		lb.impl.OnAddServer(srv)
	}
}
// UpdateConfigIfNeeded merges cfg into the current configuration and ensures
// a balancing strategy exists.
//
// Link is always taken from cfg. Mode is adopted only while the current mode
// is unset (once set it is never replaced), and adopting it rebuilds the
// strategy. Options are adopted only while the current options are empty.
// A nil cfg skips the merge and only performs the strategy check.
//
// poolMu is held for the whole call, including the cfg == nil path, because
// updateImpl replaces lb.impl and walks the pool; AddServer and RemoveServer
// use lb.impl under the same lock.
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *types.LoadBalancerConfig) {
	lb.poolMu.Lock()
	defer lb.poolMu.Unlock()

	if cfg != nil {
		lb.Link = cfg.Link
		if lb.Mode == types.LoadbalanceModeUnset && cfg.Mode != types.LoadbalanceModeUnset {
			lb.Mode = cfg.Mode
			// NOTE(review): the message below implies ValidateUpdate rewrites
			// lb.Mode to a fallback when it returns false — confirm.
			if !lb.Mode.ValidateUpdate() {
				lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
			}
			lb.updateImpl()
		}
		if len(lb.Options) == 0 && len(cfg.Options) > 0 {
			lb.Options = cfg.Options
		}
	}

	if lb.impl == nil {
		lb.updateImpl()
	}
}
// AddServer puts srv into the pool. Any server already stored under the same
// key is first removed: its weight is subtracted, it is unregistered from the
// strategy and deleted from the pool. Weights are then rebalanced and srv is
// registered with the strategy.
func (lb *LoadBalancer) AddServer(srv types.LoadBalancerServer) {
	lb.poolMu.Lock()
	defer lb.poolMu.Unlock()

	key := srv.Key()
	if prev, exists := lb.pool.Get(key); exists { // FIXME: this should be a warning
		lb.sumWeight -= prev.Weight()
		lb.impl.OnRemoveServer(prev)
		lb.pool.Del(prev)
	}

	lb.pool.Add(srv)
	lb.sumWeight += srv.Weight()
	lb.rebalance()
	lb.impl.OnAddServer(srv)
}
// RemoveServer removes the pooled server stored under srv.Key(), if any.
// Weights are rebalanced afterwards, and once the pool is empty the load
// balancer's task is finished with reason "no server left".
//
// All bookkeeping uses the instance stored in the pool rather than srv
// itself, mirroring AddServer's handling of a replaced server. The stored
// instance carries the weight that rebalance adjusted and that sumWeight
// counts, and it is the value the strategy was given via OnAddServer. A
// caller passing a different instance with the same key therefore cannot
// skew sumWeight or leave a stale registration in the strategy.
func (lb *LoadBalancer) RemoveServer(srv types.LoadBalancerServer) {
	lb.poolMu.Lock()
	defer lb.poolMu.Unlock()

	stored, ok := lb.pool.Get(srv.Key())
	if !ok {
		return
	}

	lb.pool.Del(stored)
	lb.sumWeight -= stored.Weight()
	lb.rebalance()
	lb.impl.OnRemoveServer(stored)

	lb.l.Debug().
		Str("action", "remove").
		Str("server", stored.Name()).
		Msgf("%d servers left", lb.pool.Size())

	if lb.pool.Size() == 0 {
		lb.task.Finish("no server left")
	}
}
// rebalance adjusts server weights so they sum to maxWeight, keeping
// lb.sumWeight equal to the total actually assigned. It is called by
// AddServer and RemoveServer with poolMu held. It does nothing when the sum
// already equals maxWeight or the pool is empty.
//
// When sumWeight is 0 (e.g. no server carries a weight), maxWeight is split
// evenly. The remainder of the division goes out one point at a time to the
// first servers in iteration order. Otherwise every weight is scaled by
// maxWeight/sumWeight and truncated, and the resulting difference from
// maxWeight is corrected by ±1 per server in iteration order.
func (lb *LoadBalancer) rebalance() {
	if lb.sumWeight == maxWeight {
		return
	}
	poolSize := lb.pool.Size()
	if poolSize == 0 {
		return
	}
	if lb.sumWeight == 0 { // distribute evenly
		weightEach := maxWeight / poolSize
		remainder := maxWeight % poolSize
		for _, srv := range lb.pool.Iter {
			w := weightEach
			if remainder > 0 {
				w++
				remainder--
			}
			srv.SetWeight(w)
			// Count the weight actually assigned, including any remainder
			// point, so sumWeight ends at exactly maxWeight.
			lb.sumWeight += w
		}
		return
	}
	// scale evenly
	scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
	lb.sumWeight = 0
	for _, srv := range lb.pool.Iter {
		srv.SetWeight(int(float64(srv.Weight()) * scaleFactor))
		lb.sumWeight += srv.Weight()
	}
	// Truncation above can leave the total short of maxWeight; hand out the
	// difference one point per server.
	delta := maxWeight - lb.sumWeight
	if delta == 0 {
		return
	}
	for _, srv := range lb.pool.Iter {
		if delta == 0 {
			break
		}
		if delta > 0 {
			srv.SetWeight(srv.Weight() + 1)
			lb.sumWeight++
			delta--
		} else {
			srv.SetWeight(srv.Weight() - 1)
			lb.sumWeight--
			delta++
		}
	}
}
// ServeHTTP forwards r to one of the currently healthy servers, answering 503
// when none is available or none can be chosen.
//
// Requests to the idlewatcher wake-events path first wake every healthy
// server concurrently; failures are only logged. Idlewatcher requests and, when
// Sticky is enabled, all requests reuse the server named by the sticky
// session. Without one, a server is chosen and a sticky cookie is set.
// Otherwise the strategy serves the request itself if it implements
// customServeHTTP, or picks a server via ChooseServer.
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	servers := lb.availServers()
	if len(servers) == 0 {
		http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
		return
	}

	if r.URL.Path == idlewatcher.WakeEventsPath {
		var wakeErrs gperr.Group
		for _, s := range servers {
			wakeErrs.Go(func() error {
				if err := s.TryWake(); err != nil {
					return fmt.Errorf("failed to wake server %q: %w", s.Name(), err)
				}
				return nil
			})
		}
		if err := wakeErrs.Wait().Error(); err != nil {
			gperr.LogWarn("failed to wake some servers", err, &lb.l)
		}
	}

	if lb.Sticky || isIdlewatcherRequest(r) {
		if pinned := getStickyServer(r, servers); pinned != nil {
			pinned.ServeHTTP(rw, r)
			return
		}
		// No usable sticky session yet: pick a server and pin the client to it.
		if chosen := lb.impl.ChooseServer(servers, r); chosen != nil {
			setStickyCookie(rw, r, chosen, lb.StickyMaxAge)
			chosen.ServeHTTP(rw, r)
			return
		}
	}

	if custom, ok := lb.impl.(customServeHTTP); ok {
		custom.ServeHTTP(servers, rw, r)
		return
	}

	chosen := lb.ChooseServer(servers, r)
	if chosen == nil {
		http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
		return
	}
	chosen.ServeHTTP(rw, r)
}
// MarshalJSON encodes the load balancer as a types.HealthJSONRepr. The
// encoding includes the name, aggregate status, a "healthy/total" detail
// line, start time, uptime and summed latency. Its Extra section carries the
// config and every pooled server keyed by Key().
//
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
	pooled := make(map[string]any)
	for _, srv := range lb.pool.Iter {
		pooled[srv.Key()] = srv
	}

	status, numHealthy := lb.status()
	repr := &types.HealthJSONRepr{
		Name:    lb.Name(),
		Status:  status,
		Detail:  fmt.Sprintf("%d/%d servers are healthy", numHealthy, lb.pool.Size()),
		Started: lb.startTime,
		Uptime:  lb.Uptime(),
		Latency: lb.Latency(),
		Extra: &types.HealthExtra{
			Config: lb.LoadBalancerConfig,
			Pool:   pooled,
		},
	}
	return repr.MarshalJSON()
}
// Name returns the Link field of the load balancer's configuration.
//
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
	return lb.LoadBalancerConfig.Link
}
// Status returns the aggregate health of the pool, discarding the healthy count.
//
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() types.HealthStatus {
	s, _ := lb.status()
	return s
}
// Detail returns a "<healthy>/<total> servers are healthy" summary of the pool.
//
// Detail implements health.HealthMonitor.
func (lb *LoadBalancer) Detail() string {
	_, healthy := lb.status()
	total := lb.pool.Size()
	return fmt.Sprintf("%d/%d servers are healthy", healthy, total)
}
// status returns the pool's aggregate health along with how many servers
// report a Good status. An empty pool is StatusUnknown. Otherwise the pool is
// StatusHealthy when at least one server is good, else StatusUnhealthy.
func (lb *LoadBalancer) status() (status types.HealthStatus, numHealthy int) {
	if lb.pool.Size() == 0 {
		return types.StatusUnknown, 0
	}

	for _, srv := range lb.pool.Iter {
		if srv.Status().Good() {
			numHealthy++
		}
	}

	status = types.StatusUnhealthy
	if numHealthy > 0 {
		status = types.StatusHealthy
	}
	return status, numHealthy
}
// Uptime returns the time elapsed since startTime, which Start records.
//
// Uptime implements health.HealthMonitor.
func (lb *LoadBalancer) Uptime() time.Duration {
	return time.Now().Sub(lb.startTime)
}
// Latency returns the sum of the latencies reported by every pooled server,
// healthy or not.
//
// Latency implements health.HealthMonitor.
func (lb *LoadBalancer) Latency() time.Duration {
	total := time.Duration(0)
	for _, srv := range lb.pool.Iter {
		total += srv.Latency()
	}
	return total
}
// String returns the same value as Name.
//
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
	name := lb.Name()
	return name
}
// availServers returns, in pool iteration order, the servers whose status is
// Good. The returned slice is never nil, even when no server qualifies.
func (lb *LoadBalancer) availServers() []types.LoadBalancerServer {
	healthy := make([]types.LoadBalancerServer, 0, lb.pool.Size())
	for _, srv := range lb.pool.Iter {
		if !srv.Status().Good() {
			continue
		}
		healthy = append(healthy, srv)
	}
	return healthy
}
// isIdlewatcherRequest reports whether r's path is one of the idlewatcher's
// own endpoints: wake events, favicon, or the loading page's CSS or JS.
func isIdlewatcherRequest(r *http.Request) bool {
	switch r.URL.Path {
	case idlewatcher.WakeEventsPath,
		idlewatcher.FavIconPath,
		idlewatcher.LoadingPageCSSPath,
		idlewatcher.LoadingPageJSPath:
		return true
	default:
		return false
	}
}