mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 08:48:32 +02:00
modules reorganized and code refactor
This commit is contained in:
@@ -14,7 +14,7 @@ type ipHash struct {
|
||||
*LoadBalancer
|
||||
|
||||
realIP *middleware.Middleware
|
||||
pool servers
|
||||
pool Servers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func (lb *LoadBalancer) newIPHash() impl {
|
||||
var err E.Error
|
||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
||||
if err != nil {
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
}
|
||||
return impl
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func (impl *ipHash) OnRemoveServer(srv *Server) {
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
if impl.realIP != nil {
|
||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
||||
} else {
|
||||
@@ -72,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
||||
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
|
||||
@@ -27,18 +27,18 @@ func (impl *leastConn) OnRemoveServer(srv *Server) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
|
||||
@@ -19,19 +18,12 @@ import (
|
||||
// TODO: support weighted mode.
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv *Server)
|
||||
OnRemoveServer(srv *Server)
|
||||
}
|
||||
Config struct {
|
||||
Link string `json:"link" yaml:"link"`
|
||||
Mode Mode `json:"mode" yaml:"mode"`
|
||||
Weight weightType `json:"weight" yaml:"weight"`
|
||||
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
|
||||
}
|
||||
LoadBalancer struct {
|
||||
zerolog.Logger
|
||||
|
||||
LoadBalancer struct {
|
||||
impl
|
||||
*Config
|
||||
|
||||
@@ -40,20 +32,20 @@ type (
|
||||
pool Pool
|
||||
poolMu sync.Mutex
|
||||
|
||||
sumWeight weightType
|
||||
sumWeight Weight
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
weightType uint16
|
||||
l zerolog.Logger
|
||||
}
|
||||
)
|
||||
|
||||
const maxWeight weightType = 100
|
||||
const maxWeight Weight = 100
|
||||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
Logger: logger.With().Str("name", cfg.Link).Logger(),
|
||||
Config: new(Config),
|
||||
pool: newPool(),
|
||||
pool: types.NewServerPool(),
|
||||
l: logger.With().Str("name", cfg.Link).Logger(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
@@ -81,11 +73,11 @@ func (lb *LoadBalancer) Finish(reason any) {
|
||||
|
||||
func (lb *LoadBalancer) updateImpl() {
|
||||
switch lb.Mode {
|
||||
case Unset, RoundRobin:
|
||||
case types.ModeUnset, types.ModeRoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
case types.ModeLeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
case types.ModeIPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
@@ -102,10 +94,10 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
|
||||
lb.Link = cfg.Link
|
||||
|
||||
if lb.Mode == Unset && cfg.Mode != Unset {
|
||||
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
|
||||
lb.Mode = cfg.Mode
|
||||
if !lb.Mode.ValidateUpdate() {
|
||||
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
}
|
||||
lb.updateImpl()
|
||||
}
|
||||
@@ -135,7 +127,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
|
||||
lb.Debug().
|
||||
lb.l.Debug().
|
||||
Str("action", "add").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers available", lb.pool.Size())
|
||||
@@ -155,7 +147,7 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
lb.rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
lb.Debug().
|
||||
lb.l.Debug().
|
||||
Str("action", "remove").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers left", lb.pool.Size())
|
||||
@@ -174,8 +166,8 @@ func (lb *LoadBalancer) rebalance() {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / weightType(lb.pool.Size())
|
||||
remainder := maxWeight % weightType(lb.pool.Size())
|
||||
weightEach := maxWeight / Weight(lb.pool.Size())
|
||||
remainder := maxWeight % Weight(lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightEach
|
||||
lb.sumWeight += weightEach
|
||||
@@ -192,7 +184,7 @@ func (lb *LoadBalancer) rebalance() {
|
||||
lb.sumWeight = 0
|
||||
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||
s.Weight = Weight(float64(s.Weight) * scaleFactor)
|
||||
lb.sumWeight += s.Weight
|
||||
})
|
||||
|
||||
@@ -226,13 +218,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
||||
// wake all servers
|
||||
for _, srv := range srvs {
|
||||
// wake only if server implements Waker
|
||||
waker, ok := srv.handler.(idlewatcher.Waker)
|
||||
if ok {
|
||||
if err := waker.Wake(); err != nil {
|
||||
lb.Err(err).Msgf("failed to wake server %s", srv.Name)
|
||||
}
|
||||
}
|
||||
srv.TryWake()
|
||||
}
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
@@ -246,7 +232,7 @@ func (lb *LoadBalancer) Uptime() time.Duration {
|
||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
extra := make(map[string]any)
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
extra[v.Name] = v.healthMon
|
||||
extra[v.Name] = v.HealthMonitor()
|
||||
})
|
||||
|
||||
return (&monitor.JSONRepresentation{
|
||||
|
||||
@@ -3,13 +3,14 @@ package loadbalancer
|
||||
import (
|
||||
"testing"
|
||||
|
||||
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
lb := New(new(loadbalance.Config))
|
||||
for range 10 {
|
||||
lb.AddServer(&Server{})
|
||||
}
|
||||
@@ -17,25 +18,25 @@ func TestRebalance(t *testing.T) {
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb := New(new(loadbalance.Config))
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb := New(new(loadbalance.Config))
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .4)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)})
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
Unset Mode = ""
|
||||
RoundRobin Mode = "roundrobin"
|
||||
LeastConn Mode = "leastconn"
|
||||
IPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(RoundRobin):
|
||||
*mode = RoundRobin
|
||||
return true
|
||||
case string(LeastConn):
|
||||
*mode = LeastConn
|
||||
return true
|
||||
case string(IPHash):
|
||||
*mode = IPHash
|
||||
return true
|
||||
}
|
||||
*mode = RoundRobin
|
||||
return false
|
||||
}
|
||||
@@ -13,7 +13,7 @@ func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv *Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1) % uint32(len(srvs))
|
||||
srvs[index].ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
|
||||
14
internal/net/http/loadbalancer/types.go
Normal file
14
internal/net/http/loadbalancer/types.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Server = types.Server
|
||||
Servers = types.Servers
|
||||
Pool = types.Pool
|
||||
Weight = types.Weight
|
||||
Config = types.Config
|
||||
Mode = types.Mode
|
||||
)
|
||||
8
internal/net/http/loadbalancer/types/config.go
Normal file
8
internal/net/http/loadbalancer/types/config.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package types
|
||||
|
||||
type Config struct {
|
||||
Link string `json:"link" yaml:"link"`
|
||||
Mode Mode `json:"mode" yaml:"mode"`
|
||||
Weight Weight `json:"weight" yaml:"weight"`
|
||||
Options map[string]any `json:"options,omitempty" yaml:"options,omitempty"`
|
||||
}
|
||||
32
internal/net/http/loadbalancer/types/mode.go
Normal file
32
internal/net/http/loadbalancer/types/mode.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeUnset Mode = ""
|
||||
ModeRoundRobin Mode = "roundrobin"
|
||||
ModeLeastConn Mode = "leastconn"
|
||||
ModeIPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(ModeRoundRobin):
|
||||
*mode = ModeRoundRobin
|
||||
return true
|
||||
case string(ModeLeastConn):
|
||||
*mode = ModeLeastConn
|
||||
return true
|
||||
case string(ModeIPHash):
|
||||
*mode = ModeIPHash
|
||||
return true
|
||||
}
|
||||
*mode = ModeRoundRobin
|
||||
return false
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
package loadbalancer
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
@@ -16,18 +17,18 @@ type (
|
||||
|
||||
Name string
|
||||
URL types.URL
|
||||
Weight weightType
|
||||
Weight Weight
|
||||
|
||||
handler http.Handler
|
||||
healthMon health.HealthMonitor
|
||||
}
|
||||
servers = []*Server
|
||||
Servers = []*Server
|
||||
Pool = F.Map[string, *Server]
|
||||
)
|
||||
|
||||
var newPool = F.NewMap[Pool]
|
||||
var NewServerPool = F.NewMap[Pool]
|
||||
|
||||
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
|
||||
func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) *Server {
|
||||
srv := &Server{
|
||||
Name: name,
|
||||
URL: url,
|
||||
@@ -53,3 +54,17 @@ func (srv *Server) Status() health.Status {
|
||||
func (srv *Server) Uptime() time.Duration {
|
||||
return srv.healthMon.Uptime()
|
||||
}
|
||||
|
||||
func (srv *Server) TryWake() error {
|
||||
waker, ok := srv.handler.(idlewatcher.Waker)
|
||||
if ok {
|
||||
if err := waker.Wake(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) HealthMonitor() health.HealthMonitor {
|
||||
return srv.healthMon
|
||||
}
|
||||
3
internal/net/http/loadbalancer/types/weight.go
Normal file
3
internal/net/http/loadbalancer/types/weight.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package types
|
||||
|
||||
type Weight uint16
|
||||
164
internal/net/http/server/server.go
Normal file
164
internal/net/http/server/server.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Name string
|
||||
CertProvider *autocert.Provider
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
httpStarted bool
|
||||
httpsStarted bool
|
||||
startTime time.Time
|
||||
|
||||
task task.Task
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Name string
|
||||
HTTPAddr string
|
||||
HTTPSAddr string
|
||||
CertProvider *autocert.Provider
|
||||
RedirectToHTTPS bool
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
func StartServer(opt Options) (s *Server) {
|
||||
s = NewServer(opt)
|
||||
s.Start()
|
||||
return s
|
||||
}
|
||||
|
||||
func NewServer(opt Options) (s *Server) {
|
||||
var httpSer, httpsSer *http.Server
|
||||
var httpHandler http.Handler
|
||||
|
||||
logger := logging.With().Str("module", "server").Str("name", opt.Name).Logger()
|
||||
|
||||
certAvailable := false
|
||||
if opt.CertProvider != nil {
|
||||
_, err := opt.CertProvider.GetCert(nil)
|
||||
certAvailable = err == nil
|
||||
}
|
||||
|
||||
if certAvailable && opt.RedirectToHTTPS && opt.HTTPSAddr != "" {
|
||||
httpHandler = redirectToTLSHandler(opt.HTTPSAddr)
|
||||
} else {
|
||||
httpHandler = opt.Handler
|
||||
}
|
||||
|
||||
if opt.HTTPAddr != "" {
|
||||
httpSer = &http.Server{
|
||||
Addr: opt.HTTPAddr,
|
||||
Handler: httpHandler,
|
||||
ErrorLog: log.New(io.Discard, "", 0), // most are tls related
|
||||
}
|
||||
}
|
||||
if certAvailable && opt.HTTPSAddr != "" {
|
||||
httpsSer = &http.Server{
|
||||
Addr: opt.HTTPSAddr,
|
||||
Handler: opt.Handler,
|
||||
ErrorLog: log.New(io.Discard, "", 0), // most are tls related
|
||||
TLSConfig: &tls.Config{
|
||||
GetCertificate: opt.CertProvider.GetCert,
|
||||
},
|
||||
}
|
||||
}
|
||||
return &Server{
|
||||
Name: opt.Name,
|
||||
CertProvider: opt.CertProvider,
|
||||
http: httpSer,
|
||||
https: httpsSer,
|
||||
task: task.GlobalTask(opt.Name + " server"),
|
||||
l: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start will start the http and https servers.
|
||||
//
|
||||
// If both are not set, this does nothing.
|
||||
//
|
||||
// Start() is non-blocking.
|
||||
func (s *Server) Start() {
|
||||
if s.http == nil && s.https == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.startTime = time.Now()
|
||||
if s.http != nil {
|
||||
go func() {
|
||||
s.handleErr("http", s.http.ListenAndServe())
|
||||
}()
|
||||
s.httpStarted = true
|
||||
s.l.Info().Str("addr", s.http.Addr).Msg("server started")
|
||||
}
|
||||
|
||||
if s.https != nil {
|
||||
go func() {
|
||||
s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath()))
|
||||
}()
|
||||
s.httpsStarted = true
|
||||
s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
|
||||
}
|
||||
|
||||
s.task.OnFinished("stop server", s.stop)
|
||||
}
|
||||
|
||||
func (s *Server) stop() {
|
||||
if s.http == nil && s.https == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.http != nil && s.httpStarted {
|
||||
s.handleErr("http", s.http.Shutdown(s.task.Context()))
|
||||
s.httpStarted = false
|
||||
}
|
||||
|
||||
if s.https != nil && s.httpsStarted {
|
||||
s.handleErr("https", s.https.Shutdown(s.task.Context()))
|
||||
s.httpsStarted = false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Uptime() time.Duration {
|
||||
return time.Since(s.startTime)
|
||||
}
|
||||
|
||||
func (s *Server) handleErr(scheme string, err error) {
|
||||
switch {
|
||||
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
|
||||
return
|
||||
default:
|
||||
s.l.Fatal().Err(err).Str("scheme", scheme).Msg("server error")
|
||||
}
|
||||
}
|
||||
|
||||
func redirectToTLSHandler(port string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
r.URL.Scheme = "https"
|
||||
r.URL.Host = r.URL.Hostname() + port
|
||||
|
||||
var redirectCode int
|
||||
if r.Method == http.MethodGet {
|
||||
redirectCode = http.StatusMovedPermanently
|
||||
} else {
|
||||
redirectCode = http.StatusPermanentRedirect
|
||||
}
|
||||
http.Redirect(w, r, r.URL.String(), redirectCode)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user