modules reorganized and code refactor

This commit is contained in:
yusing
2024-11-25 01:40:12 +08:00
parent f3b21e6bd9
commit d723403b6b
46 changed files with 437 additions and 331 deletions

View File

@@ -14,7 +14,7 @@ type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool servers
pool Servers
mu sync.Mutex
}
@@ -26,7 +26,7 @@ func (lb *LoadBalancer) newIPHash() impl {
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
}
return impl
}
@@ -60,7 +60,7 @@ func (impl *ipHash) OnRemoveServer(srv *Server) {
}
}
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
@@ -72,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))

View File

@@ -27,18 +27,18 @@ func (impl *leastConn) OnRemoveServer(srv *Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {

View File

@@ -7,9 +7,8 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
@@ -19,19 +18,12 @@ import (
// TODO: support weighted mode.
type (
impl interface {
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv *Server)
OnRemoveServer(srv *Server)
}
Config struct {
Link string `json:"link" yaml:"link"`
Mode Mode `json:"mode" yaml:"mode"`
Weight weightType `json:"weight" yaml:"weight"`
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
}
LoadBalancer struct {
zerolog.Logger
LoadBalancer struct {
impl
*Config
@@ -40,20 +32,20 @@ type (
pool Pool
poolMu sync.Mutex
sumWeight weightType
sumWeight Weight
startTime time.Time
}
weightType uint16
l zerolog.Logger
}
)
const maxWeight weightType = 100
const maxWeight Weight = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Logger: logger.With().Str("name", cfg.Link).Logger(),
Config: new(Config),
pool: newPool(),
pool: types.NewServerPool(),
l: logger.With().Str("name", cfg.Link).Logger(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
@@ -81,11 +73,11 @@ func (lb *LoadBalancer) Finish(reason any) {
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case Unset, RoundRobin:
case types.ModeUnset, types.ModeRoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
case types.ModeLeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
case types.ModeIPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
@@ -102,10 +94,10 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
lb.Link = cfg.Link
if lb.Mode == Unset && cfg.Mode != Unset {
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
@@ -135,7 +127,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.rebalance()
lb.impl.OnAddServer(srv)
lb.Debug().
lb.l.Debug().
Str("action", "add").
Str("server", srv.Name).
Msgf("%d servers available", lb.pool.Size())
@@ -155,7 +147,7 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.rebalance()
lb.impl.OnRemoveServer(srv)
lb.Debug().
lb.l.Debug().
Str("action", "remove").
Str("server", srv.Name).
Msgf("%d servers left", lb.pool.Size())
@@ -174,8 +166,8 @@ func (lb *LoadBalancer) rebalance() {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(lb.pool.Size())
remainder := maxWeight % weightType(lb.pool.Size())
weightEach := maxWeight / Weight(lb.pool.Size())
remainder := maxWeight % Weight(lb.pool.Size())
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightEach
lb.sumWeight += weightEach
@@ -192,7 +184,7 @@ func (lb *LoadBalancer) rebalance() {
lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightType(float64(s.Weight) * scaleFactor)
s.Weight = Weight(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight
})
@@ -226,13 +218,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if r.Header.Get(common.HeaderCheckRedirect) != "" {
// wake all servers
for _, srv := range srvs {
// wake only if server implements Waker
waker, ok := srv.handler.(idlewatcher.Waker)
if ok {
if err := waker.Wake(); err != nil {
lb.Err(err).Msgf("failed to wake server %s", srv.Name)
}
}
srv.TryWake()
}
}
lb.impl.ServeHTTP(srvs, rw, r)
@@ -246,7 +232,7 @@ func (lb *LoadBalancer) Uptime() time.Duration {
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v *Server) {
extra[v.Name] = v.healthMon
extra[v.Name] = v.HealthMonitor()
})
return (&monitor.JSONRepresentation{

View File

@@ -3,13 +3,14 @@ package loadbalancer
import (
"testing"
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(new(Config))
lb := New(new(loadbalance.Config))
for range 10 {
lb.AddServer(&Server{})
}
@@ -17,25 +18,25 @@ func TestRebalance(t *testing.T) {
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(new(Config))
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb := New(new(loadbalance.Config))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(new(Config))
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb := New(new(loadbalance.Config))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .4)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)

View File

@@ -1,32 +0,0 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Mode string
const (
Unset Mode = ""
RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn"
IPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch strutils.ToLowerNoSnake(string(*mode)) {
case "":
return true
case string(RoundRobin):
*mode = RoundRobin
return true
case string(LeastConn):
*mode = LeastConn
return true
case string(IPHash):
*mode = IPHash
return true
}
*mode = RoundRobin
return false
}

View File

@@ -13,7 +13,7 @@ func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) % uint32(len(srvs))
srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {

View File

@@ -0,0 +1,14 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
)
type (
Server = types.Server
Servers = types.Servers
Pool = types.Pool
Weight = types.Weight
Config = types.Config
Mode = types.Mode
)

View File

@@ -0,0 +1,8 @@
package types
type Config struct {
Link string `json:"link" yaml:"link"`
Mode Mode `json:"mode" yaml:"mode"`
Weight Weight `json:"weight" yaml:"weight"`
Options map[string]any `json:"options,omitempty" yaml:"options,omitempty"`
}

View File

@@ -0,0 +1,32 @@
package types
import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Mode string
const (
ModeUnset Mode = ""
ModeRoundRobin Mode = "roundrobin"
ModeLeastConn Mode = "leastconn"
ModeIPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch strutils.ToLowerNoSnake(string(*mode)) {
case "":
return true
case string(ModeRoundRobin):
*mode = ModeRoundRobin
return true
case string(ModeLeastConn):
*mode = ModeLeastConn
return true
case string(ModeIPHash):
*mode = ModeIPHash
return true
}
*mode = ModeRoundRobin
return false
}

View File

@@ -1,9 +1,10 @@
package loadbalancer
package types
import (
"net/http"
"time"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@@ -16,18 +17,18 @@ type (
Name string
URL types.URL
Weight weightType
Weight Weight
handler http.Handler
healthMon health.HealthMonitor
}
servers = []*Server
Servers = []*Server
Pool = F.Map[string, *Server]
)
var newPool = F.NewMap[Pool]
var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{
Name: name,
URL: url,
@@ -53,3 +54,17 @@ func (srv *Server) Status() health.Status {
func (srv *Server) Uptime() time.Duration {
return srv.healthMon.Uptime()
}
func (srv *Server) TryWake() error {
waker, ok := srv.handler.(idlewatcher.Waker)
if ok {
if err := waker.Wake(); err != nil {
return err
}
}
return nil
}
func (srv *Server) HealthMonitor() health.HealthMonitor {
return srv.healthMon
}

View File

@@ -0,0 +1,3 @@
package types
type Weight uint16

View File

@@ -0,0 +1,164 @@
package server
import (
"context"
"crypto/tls"
"errors"
"io"
"log"
"net/http"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type Server struct {
Name string
CertProvider *autocert.Provider
http *http.Server
https *http.Server
httpStarted bool
httpsStarted bool
startTime time.Time
task task.Task
l zerolog.Logger
}
type Options struct {
Name string
HTTPAddr string
HTTPSAddr string
CertProvider *autocert.Provider
RedirectToHTTPS bool
Handler http.Handler
}
func StartServer(opt Options) (s *Server) {
s = NewServer(opt)
s.Start()
return s
}
func NewServer(opt Options) (s *Server) {
var httpSer, httpsSer *http.Server
var httpHandler http.Handler
logger := logging.With().Str("module", "server").Str("name", opt.Name).Logger()
certAvailable := false
if opt.CertProvider != nil {
_, err := opt.CertProvider.GetCert(nil)
certAvailable = err == nil
}
if certAvailable && opt.RedirectToHTTPS && opt.HTTPSAddr != "" {
httpHandler = redirectToTLSHandler(opt.HTTPSAddr)
} else {
httpHandler = opt.Handler
}
if opt.HTTPAddr != "" {
httpSer = &http.Server{
Addr: opt.HTTPAddr,
Handler: httpHandler,
ErrorLog: log.New(io.Discard, "", 0), // most are tls related
}
}
if certAvailable && opt.HTTPSAddr != "" {
httpsSer = &http.Server{
Addr: opt.HTTPSAddr,
Handler: opt.Handler,
ErrorLog: log.New(io.Discard, "", 0), // most are tls related
TLSConfig: &tls.Config{
GetCertificate: opt.CertProvider.GetCert,
},
}
}
return &Server{
Name: opt.Name,
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
task: task.GlobalTask(opt.Name + " server"),
l: logger,
}
}
// Start will start the http and https servers.
//
// If both are not set, this does nothing.
//
// Start() is non-blocking.
func (s *Server) Start() {
if s.http == nil && s.https == nil {
return
}
s.startTime = time.Now()
if s.http != nil {
go func() {
s.handleErr("http", s.http.ListenAndServe())
}()
s.httpStarted = true
s.l.Info().Str("addr", s.http.Addr).Msg("server started")
}
if s.https != nil {
go func() {
s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath()))
}()
s.httpsStarted = true
s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
}
s.task.OnFinished("stop server", s.stop)
}
func (s *Server) stop() {
if s.http == nil && s.https == nil {
return
}
if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(s.task.Context()))
s.httpStarted = false
}
if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(s.task.Context()))
s.httpsStarted = false
}
}
func (s *Server) Uptime() time.Duration {
return time.Since(s.startTime)
}
func (s *Server) handleErr(scheme string, err error) {
switch {
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return
default:
s.l.Fatal().Err(err).Str("scheme", scheme).Msg("server error")
}
}
func redirectToTLSHandler(port string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + port
var redirectCode int
if r.Method == http.MethodGet {
redirectCode = http.StatusMovedPermanently
} else {
redirectCode = http.StatusPermanentRedirect
}
http.Redirect(w, r, r.URL.String(), redirectCode)
}
}