refactor and organize code

This commit is contained in:
yusing
2025-02-15 05:44:47 +08:00
parent 1af6dd9cf8
commit 18d258aaa2
169 changed files with 1020 additions and 755 deletions

View File

@@ -0,0 +1,91 @@
package loadbalancer
import (
"hash/fnv"
"net"
"net/http"
"sync"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
)
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool Servers
mu sync.Mutex
}
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err gperr.Error
impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil {
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
}
return impl
}
func (impl *ipHash) OnAddServer(srv Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
return
}
if s == nil {
impl.pool[i] = srv
return
}
}
impl.pool = append(impl.pool, srv)
}
func (impl *ipHash) OnRemoveServer(srv Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
impl.pool[i] = nil
return
}
}
}
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
impl.serveHTTP(rw, r)
}
}
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
srv := impl.pool[idx]
if srv == nil || srv.Status().Bad() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
srv.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {
h := fnv.New32a()
h.Write([]byte(ip))
return h.Sum32()
}

View File

@@ -0,0 +1,53 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type leastConn struct {
*LoadBalancer
nConn F.Map[Server, *atomic.Int64]
}
func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{
LoadBalancer: lb,
nConn: F.NewMapOf[Server, *atomic.Int64](),
}
}
func (impl *leastConn) OnAddServer(srv Server) {
impl.nConn.Store(srv, new(atomic.Int64))
}
func (impl *leastConn) OnRemoveServer(srv Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {
minConn = nConn
srv = srvs[i]
}
}
minConn.Add(1)
srv.ServeHTTP(rw, r)
minConn.Add(-1)
}

View File

@@ -0,0 +1,314 @@
package loadbalancer
import (
"net/http"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
// TODO: stats of each server.
// TODO: support weighted mode.
type (
impl interface {
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv Server)
OnRemoveServer(srv Server)
}
LoadBalancer struct {
impl
*Config
task *task.Task
pool Pool
poolMu sync.Mutex
sumWeight Weight
startTime time.Time
l zerolog.Logger
}
)
const maxWeight Weight = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Config: new(Config),
pool: types.NewServerPool(),
l: logging.With().Str("name", cfg.Link).Logger(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
lb.startTime = time.Now()
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
parent.OnCancel("lb_remove_route", func() {
routes.DeleteHTTPRoute(lb.Link)
})
lb.task.OnFinished("cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v Server) {
lb.impl.OnRemoveServer(v)
})
}
})
return nil
}
// Task implements task.TaskStarter.
func (lb *LoadBalancer) Task() *task.Task {
return lb.task
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason)
}
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case types.ModeUnset, types.ModeRoundRobin:
lb.impl = lb.newRoundRobin()
case types.ModeLeastConn:
lb.impl = lb.newLeastConn()
case types.ModeIPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
lb.pool.RangeAll(func(_ string, srv Server) {
lb.impl.OnAddServer(srv)
})
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if cfg != nil {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.Link = cfg.Link
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
lb.Options = cfg.Options
}
}
if lb.impl == nil {
lb.updateImpl()
}
}
func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name()) {
old, _ := lb.pool.Load(srv.Name())
lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name(), srv)
lb.sumWeight += srv.Weight()
lb.rebalance()
lb.impl.OnAddServer(srv)
lb.l.Debug().
Str("action", "add").
Str("server", srv.Name()).
Msgf("%d servers available", lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name()) {
return
}
lb.pool.Delete(srv.Name())
lb.sumWeight -= srv.Weight()
lb.rebalance()
lb.impl.OnRemoveServer(srv)
lb.l.Debug().
Str("action", "remove").
Str("server", srv.Name()).
Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
return
}
}
func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == maxWeight {
return
}
poolSize := lb.pool.Size()
if poolSize == 0 {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / Weight(poolSize)
remainder := maxWeight % Weight(poolSize)
lb.pool.RangeAll(func(_ string, s Server) {
w := weightEach
lb.sumWeight += weightEach
if remainder > 0 {
w++
remainder--
}
s.SetWeight(w)
})
return
}
// scale evenly
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s Server) {
s.SetWeight(Weight(float64(s.Weight()) * scaleFactor))
lb.sumWeight += s.Weight()
})
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
lb.pool.Range(func(_ string, s Server) bool {
if delta == 0 {
return false
}
if delta > 0 {
s.SetWeight(s.Weight() + 1)
lb.sumWeight++
delta--
} else {
s.SetWeight(s.Weight() - 1)
lb.sumWeight--
delta++
}
return true
})
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srvs := lb.availServers()
if len(srvs) == 0 {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
// wake all servers
for _, srv := range srvs {
if err := srv.TryWake(); err != nil {
lb.l.Warn().Err(err).
Str("server", srv.Name()).
Msg("failed to wake server")
}
}
}
lb.impl.ServeHTTP(srvs, rw, r)
}
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) {
extra[v.Name()] = v
})
return (&monitor.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
Uptime: lb.Uptime(),
Extra: map[string]any{
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
return lb.Link
}
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if lb.pool.Size() == 0 {
return health.StatusUnknown
}
isHealthy := true
lb.pool.Range(func(_ string, srv Server) bool {
if srv.Status().Bad() {
isHealthy = false
return false
}
return true
})
if !isHealthy {
return health.StatusUnhealthy
}
return health.StatusHealthy
}
// Uptime implements health.HealthMonitor.
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// Latency implements health.HealthMonitor.
func (lb *LoadBalancer) Latency() time.Duration {
var sum time.Duration
lb.pool.RangeAll(func(_ string, srv Server) {
sum += srv.Latency()
})
return sum
}
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() []Server {
avail := make([]Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv Server) {
if srv.Status().Good() {
avail = append(avail, srv)
}
})
return avail
}

View File

@@ -0,0 +1,44 @@
package loadbalancer
import (
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(new(types.Config))
for range 10 {
lb.AddServer(types.TestNewServer(0))
}
lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(new(types.Config))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(new(types.Config))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .4))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
}

View File

@@ -0,0 +1,22 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
)
type roundRobin struct {
index atomic.Uint32
}
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv Server) {}
func (lb *roundRobin) OnRemoveServer(srv Server) {}
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) % uint32(len(srvs))
srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}
}

View File

@@ -0,0 +1,14 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
)
type (
Server = types.Server
Servers = []types.Server
Pool = types.Pool
Weight = types.Weight
Config = types.Config
Mode = types.Mode
)

View File

@@ -0,0 +1,8 @@
package types
type Config struct {
Link string `json:"link"`
Mode Mode `json:"mode"`
Weight Weight `json:"weight"`
Options map[string]any `json:"options,omitempty"`
}

View File

@@ -0,0 +1,32 @@
package types
import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Mode string
const (
ModeUnset Mode = ""
ModeRoundRobin Mode = "roundrobin"
ModeLeastConn Mode = "leastconn"
ModeIPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch strutils.ToLowerNoSnake(string(*mode)) {
case "":
return true
case string(ModeRoundRobin):
*mode = ModeRoundRobin
return true
case string(ModeLeastConn):
*mode = ModeLeastConn
return true
case string(ModeIPHash):
*mode = ModeIPHash
return true
}
*mode = ModeRoundRobin
return false
}

View File

@@ -0,0 +1,86 @@
package types
import (
"net/http"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
net "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
server struct {
_ U.NoCopy
name string
url *net.URL
weight Weight
http.Handler `json:"-"`
health.HealthMonitor
}
Server interface {
http.Handler
health.HealthMonitor
Name() string
URL() *net.URL
Weight() Weight
SetWeight(weight Weight)
TryWake() error
}
Pool = F.Map[string, Server]
)
var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &server{
name: name,
url: url,
weight: weight,
Handler: handler,
HealthMonitor: healthMon,
}
return srv
}
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{
weight: Weight(weight),
}
return srv
}
func (srv *server) Name() string {
return srv.name
}
func (srv *server) URL() *net.URL {
return srv.url
}
func (srv *server) Weight() Weight {
return srv.weight
}
func (srv *server) SetWeight(weight Weight) {
srv.weight = weight
}
func (srv *server) String() string {
return srv.name
}
func (srv *server) TryWake() error {
waker, ok := srv.Handler.(idlewatcher.Waker)
if ok {
if err := waker.Wake(); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,3 @@
package types
type Weight uint16