refactored some stuff, added healthcheck support, fixed 'include file' reload not showing in log

This commit is contained in:
yusing
2024-10-12 13:56:38 +08:00
parent 64e30f59e8
commit d47b672aa5
41 changed files with 783 additions and 421 deletions

View File

@@ -4,15 +4,40 @@ import (
"hash/fnv"
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
type ipHash struct{ *LoadBalancer }
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
}
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err E.NestedError
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)
}
return impl
}
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
impl.serveHTTP(rw, r)
}
}
func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
@@ -20,7 +45,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
if !impl.pool[idx].available.Load() {
if !impl.pool[idx].IsHealthy() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
impl.pool[idx].handler.ServeHTTP(rw, r)

View File

@@ -1,13 +1,12 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
// TODO: stats of each server.
@@ -19,20 +18,17 @@ type (
OnRemoveServer(srv *Server)
}
Config struct {
Link string
Mode Mode
Weight weightType
Link string `json:"link" yaml:"link"`
Mode Mode `json:"mode" yaml:"mode"`
Weight weightType `json:"weight" yaml:"weight"`
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
}
LoadBalancer struct {
impl
Config
pool servers
poolMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
done chan struct{}
poolMu sync.Mutex
sumWeight weightType
}
@@ -73,8 +69,8 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.RLock()
defer lb.poolMu.RUnlock()
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.impl.OnRemoveServer(srv)
@@ -85,7 +81,7 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
}
}
if lb.IsEmpty() {
lb.Stop()
lb.pool = nil
return
}
@@ -171,54 +167,12 @@ func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
lb.done = make(chan struct{}, 1)
lb.ctx, lb.cancel = context.WithCancel(context.Background())
updateAll := func() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
var wg sync.WaitGroup
wg.Add(len(lb.pool))
for _, s := range lb.pool {
go func(s *Server) {
defer wg.Done()
s.checkUpdateAvail(lb.ctx)
}(s)
}
wg.Wait()
}
logger.Debugf("loadbalancer %s started", lb.Link)
go func() {
defer lb.cancel()
defer close(lb.done)
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
updateAll()
for {
select {
case <-lb.ctx.Done():
return
case <-ticker.C:
updateAll()
}
}
}()
}
func (lb *LoadBalancer) Stop() {
if lb.cancel == nil {
return
}
lb.cancel()
<-lb.done
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
@@ -228,9 +182,9 @@ func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
avail := servers{}
avail := make(servers, 0, len(lb.pool))
for _, s := range lb.pool {
if s.available.Load() {
if s.IsHealthy() {
avail = append(avail, s)
}
}

View File

@@ -1,67 +1,42 @@
package loadbalancer
import (
"context"
"net/http"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
Server struct {
Name string
URL types.URL
Weight weightType
handler http.Handler
_ U.NoCopy
pinger *http.Client
available atomic.Bool
Name string
URL types.URL
Weight weightType
handler http.Handler
healthMon health.HealthMonitor
}
servers []*Server
)
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{
Name: name,
URL: url,
Weight: weight,
handler: handler,
pinger: &http.Client{Timeout: 3 * time.Second},
Name: name,
URL: url,
Weight: weight,
handler: handler,
healthMon: healthMon,
}
srv.available.Store(true)
return srv
}
func (srv *Server) checkUpdateAvail(ctx context.Context) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
srv.URL.String(),
nil,
)
if err != nil {
logger.Error("failed to create request: ", err)
srv.available.Store(false)
}
resp, err := srv.pinger.Do(req)
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
if !srv.available.Swap(true) {
logger.Infof("server %s is up", srv.Name)
}
} else if err != nil {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: %s", srv.Name, err)
}
} else {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
}
}
}
func (srv *Server) String() string {
return srv.Name
}
func (srv *Server) IsHealthy() bool {
return srv.healthMon.IsHealthy()
}

View File

@@ -30,6 +30,8 @@ type (
Options any
Middleware struct {
_ U.NoCopy
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
@@ -77,30 +79,37 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
return m.withOptions(optsRaw)
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
m.name,
m.before,
m.modifyResponse,
nil,
m.impl,
m.parent,
m.children,
false,
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
}
return nil
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) {
middlewares = make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
@@ -124,10 +133,15 @@ func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[strin
middlewares = append(middlewares, m)
}
if invalidM.HasError() {
return
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rpName, rp, middlewares)
return
}

View File

@@ -114,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr)
rp := gphttp.NewReverseProxy("test", types.NewURL(proxyURL), &rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr

View File

@@ -86,7 +86,8 @@ type ReverseProxy struct {
ServeHTTP http.HandlerFunc
TargetURL types.URL
TargetName string
TargetURL types.URL
}
func singleJoiningSlash(a, b string) string {
@@ -144,11 +145,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{Transport: transport, TargetURL: target}
rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target}
rp.ServeHTTP = rp.serveHTTP
return rp
}
@@ -194,9 +195,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF):
logger.Debugf("http proxy to %s error: %s", r.URL.String(), err)
logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
default:
logger.Errorf("http proxy to %s error: %s", r.URL.String(), err)
logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
}
if writeHeader {
rw.WriteHeader(http.StatusBadGateway)

View File

@@ -9,20 +9,21 @@ import (
type CIDR net.IPNet
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
cidr, ok := val.(string)
func (cidr *CIDR) ConvertFrom(val any) E.NestedError {
cidrStr, ok := val.(string)
if !ok {
return nil, E.TypeMismatch[string](val)
return E.TypeMismatch[string](val)
}
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
if !strings.Contains(cidrStr, "/") {
cidrStr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
_, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil {
return nil, E.Invalid("CIDR", cidr)
return E.Invalid("CIDR", cidr)
}
return (*CIDR)(ipnet), nil
*cidr = CIDR(*ipnet)
return nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {

View File

@@ -1,10 +1,22 @@
package types
import "net/url"
import (
urlPkg "net/url"
)
type URL struct{ *url.URL }
type URL struct {
*urlPkg.URL
}
func NewURL(url *url.URL) URL {
func ParseURL(url string) (URL, error) {
u, err := urlPkg.Parse(url)
if err != nil {
return URL{}, err
}
return URL{URL: u}, nil
}
func NewURL(url *urlPkg.URL) URL {
return URL{url}
}
@@ -19,6 +31,10 @@ func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u URL) Equals(other URL) bool {
func (u URL) Equals(other *URL) bool {
return u.URL == other.URL || u.String() == other.String()
}
func (u URL) JoinPath(path string) URL {
return URL{u.URL.JoinPath(path)}
}