mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-18 06:29:42 +02:00
refactor(api): restructured API for type safety, maintainability and docs generation
- These changes makes the API incombatible with previous versions - Added new types for error handling, success responses, and health checks. - Updated health check logic to utilize the new types for better clarity and structure. - Refactored existing handlers to improve response consistency and error handling. - Updated Makefile to include a new target for generating API types from Swagger. - Updated "new agent" API to respond an encrypted cert pair
This commit is contained in:
@@ -1,30 +0,0 @@
|
||||
package gpwebsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
conn *websocket.Conn
|
||||
msgType int
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewWriter(ctx context.Context, conn *websocket.Conn, msgType int) *Writer {
|
||||
return &Writer{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
msgType: msgType,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (int, error) {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return 0, w.ctx.Err()
|
||||
default:
|
||||
return len(p), w.conn.WriteMessage(w.msgType, p)
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,14 @@ import (
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
*LoadBalancer
|
||||
|
||||
realIP *middleware.Middleware
|
||||
pool Servers
|
||||
pool types.LoadBalancerServers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -31,7 +32,7 @@ func (lb *LoadBalancer) newIPHash() impl {
|
||||
return impl
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnAddServer(srv Server) {
|
||||
func (impl *ipHash) OnAddServer(srv types.LoadBalancerServer) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
@@ -48,7 +49,7 @@ func (impl *ipHash) OnAddServer(srv Server) {
|
||||
impl.pool = append(impl.pool, srv)
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnRemoveServer(srv Server) {
|
||||
func (impl *ipHash) OnRemoveServer(srv types.LoadBalancerServer) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
@@ -60,7 +61,7 @@ func (impl *ipHash) OnRemoveServer(srv Server) {
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *ipHash) ServeHTTP(_ types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
||||
if impl.realIP != nil {
|
||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
||||
} else {
|
||||
|
||||
@@ -4,30 +4,31 @@ import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
||||
type leastConn struct {
|
||||
*LoadBalancer
|
||||
nConn F.Map[Server, *atomic.Int64]
|
||||
nConn *xsync.Map[types.LoadBalancerServer, *atomic.Int64]
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newLeastConn() impl {
|
||||
return &leastConn{
|
||||
LoadBalancer: lb,
|
||||
nConn: F.NewMapOf[Server, *atomic.Int64](),
|
||||
nConn: xsync.NewMap[types.LoadBalancerServer, *atomic.Int64](),
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnAddServer(srv Server) {
|
||||
func (impl *leastConn) OnAddServer(srv types.LoadBalancerServer) {
|
||||
impl.nConn.Store(srv, new(atomic.Int64))
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnRemoveServer(srv Server) {
|
||||
func (impl *leastConn) OnRemoveServer(srv types.LoadBalancerServer) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *leastConn) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
|
||||
@@ -10,44 +10,43 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/pool"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// TODO: stats of each server.
|
||||
// TODO: support weighted mode.
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv Server)
|
||||
OnRemoveServer(srv Server)
|
||||
ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv types.LoadBalancerServer)
|
||||
OnRemoveServer(srv types.LoadBalancerServer)
|
||||
}
|
||||
|
||||
LoadBalancer struct {
|
||||
impl
|
||||
*Config
|
||||
*types.LoadBalancerConfig
|
||||
|
||||
task *task.Task
|
||||
|
||||
pool pool.Pool[Server]
|
||||
pool pool.Pool[types.LoadBalancerServer]
|
||||
poolMu sync.Mutex
|
||||
|
||||
sumWeight Weight
|
||||
sumWeight int
|
||||
startTime time.Time
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
)
|
||||
|
||||
const maxWeight Weight = 100
|
||||
const maxWeight int = 100
|
||||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
func New(cfg *types.LoadBalancerConfig) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
Config: new(Config),
|
||||
pool: pool.New[Server]("loadbalancer." + cfg.Link),
|
||||
l: log.With().Str("name", cfg.Link).Logger(),
|
||||
LoadBalancerConfig: cfg,
|
||||
pool: pool.New[types.LoadBalancerServer]("loadbalancer." + cfg.Link),
|
||||
l: log.With().Str("name", cfg.Link).Logger(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
@@ -80,11 +79,11 @@ func (lb *LoadBalancer) Finish(reason any) {
|
||||
|
||||
func (lb *LoadBalancer) updateImpl() {
|
||||
switch lb.Mode {
|
||||
case types.ModeUnset, types.ModeRoundRobin:
|
||||
case types.LoadbalanceModeUnset, types.LoadbalanceModeRoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case types.ModeLeastConn:
|
||||
case types.LoadbalanceModeLeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case types.ModeIPHash:
|
||||
case types.LoadbalanceModeIPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
@@ -94,14 +93,14 @@ func (lb *LoadBalancer) updateImpl() {
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *types.LoadBalancerConfig) {
|
||||
if cfg != nil {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.Link = cfg.Link
|
||||
|
||||
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
|
||||
if lb.Mode == types.LoadbalanceModeUnset && cfg.Mode != types.LoadbalanceModeUnset {
|
||||
lb.Mode = cfg.Mode
|
||||
if !lb.Mode.ValidateUpdate() {
|
||||
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
@@ -119,7 +118,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) AddServer(srv Server) {
|
||||
func (lb *LoadBalancer) AddServer(srv types.LoadBalancerServer) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
@@ -135,7 +134,7 @@ func (lb *LoadBalancer) AddServer(srv Server) {
|
||||
lb.impl.OnAddServer(srv)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv Server) {
|
||||
func (lb *LoadBalancer) RemoveServer(srv types.LoadBalancerServer) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
@@ -170,8 +169,8 @@ func (lb *LoadBalancer) rebalance() {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / Weight(poolSize)
|
||||
remainder := maxWeight % Weight(poolSize)
|
||||
weightEach := maxWeight / poolSize
|
||||
remainder := maxWeight % poolSize
|
||||
for _, srv := range lb.pool.Iter {
|
||||
w := weightEach
|
||||
lb.sumWeight += weightEach
|
||||
@@ -189,7 +188,7 @@ func (lb *LoadBalancer) rebalance() {
|
||||
lb.sumWeight = 0
|
||||
|
||||
for _, srv := range lb.pool.Iter {
|
||||
srv.SetWeight(Weight(float64(srv.Weight()) * scaleFactor))
|
||||
srv.SetWeight(int(float64(srv.Weight()) * scaleFactor))
|
||||
lb.sumWeight += srv.Weight()
|
||||
}
|
||||
|
||||
@@ -241,16 +240,16 @@ func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
|
||||
status, numHealthy := lb.status()
|
||||
|
||||
return (&health.JSONRepresentation{
|
||||
return (&types.HealthJSONRepr{
|
||||
Name: lb.Name(),
|
||||
Status: status,
|
||||
Detail: fmt.Sprintf("%d/%d servers are healthy", numHealthy, lb.pool.Size()),
|
||||
Started: lb.startTime,
|
||||
Uptime: lb.Uptime(),
|
||||
Latency: lb.Latency(),
|
||||
Extra: map[string]any{
|
||||
"config": lb.Config,
|
||||
"pool": extra,
|
||||
Extra: &types.HealthExtra{
|
||||
Config: lb.LoadBalancerConfig,
|
||||
Pool: extra,
|
||||
},
|
||||
}).MarshalJSON()
|
||||
}
|
||||
@@ -261,7 +260,7 @@ func (lb *LoadBalancer) Name() string {
|
||||
}
|
||||
|
||||
// Status implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Status() health.Status {
|
||||
func (lb *LoadBalancer) Status() types.HealthStatus {
|
||||
status, _ := lb.status()
|
||||
return status
|
||||
}
|
||||
@@ -272,9 +271,9 @@ func (lb *LoadBalancer) Detail() string {
|
||||
return fmt.Sprintf("%d/%d servers are healthy", numHealthy, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) status() (status health.Status, numHealthy int) {
|
||||
func (lb *LoadBalancer) status() (status types.HealthStatus, numHealthy int) {
|
||||
if lb.pool.Size() == 0 {
|
||||
return health.StatusUnknown, 0
|
||||
return types.StatusUnknown, 0
|
||||
}
|
||||
|
||||
// should be healthy if at least one server is healthy
|
||||
@@ -285,9 +284,9 @@ func (lb *LoadBalancer) status() (status health.Status, numHealthy int) {
|
||||
}
|
||||
}
|
||||
if numHealthy == 0 {
|
||||
return health.StatusUnhealthy, numHealthy
|
||||
return types.StatusUnhealthy, numHealthy
|
||||
}
|
||||
return health.StatusHealthy, numHealthy
|
||||
return types.StatusHealthy, numHealthy
|
||||
}
|
||||
|
||||
// Uptime implements health.HealthMonitor.
|
||||
@@ -309,8 +308,8 @@ func (lb *LoadBalancer) String() string {
|
||||
return lb.Name()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() []Server {
|
||||
avail := make([]Server, 0, lb.pool.Size())
|
||||
func (lb *LoadBalancer) availServers() []types.LoadBalancerServer {
|
||||
avail := make([]types.LoadBalancerServer, 0, lb.pool.Size())
|
||||
for _, srv := range lb.pool.Iter {
|
||||
if srv.Status().Good() {
|
||||
avail = append(avail, srv)
|
||||
|
||||
@@ -3,40 +3,40 @@ package loadbalancer
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
lb := New(new(types.LoadBalancerConfig))
|
||||
for range 10 {
|
||||
lb.AddServer(types.TestNewServer(0))
|
||||
lb.AddServer(TestNewServer(0))
|
||||
}
|
||||
lb.rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb := New(new(types.LoadBalancerConfig))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .1))
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .4))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb := New(new(types.LoadBalancerConfig))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .4))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(TestNewServer(float64(maxWeight) * .1))
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
|
||||
@@ -3,17 +3,19 @@ package loadbalancer
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
||||
type roundRobin struct {
|
||||
index atomic.Uint32
|
||||
}
|
||||
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv Server) {}
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv types.LoadBalancerServer) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv types.LoadBalancerServer) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (lb *roundRobin) ServeHTTP(srvs types.LoadBalancerServers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1) % uint32(len(srvs))
|
||||
srvs[index].ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
|
||||
@@ -1,39 +1,26 @@
|
||||
package types
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type (
|
||||
server struct {
|
||||
_ U.NoCopy
|
||||
type server struct {
|
||||
_ U.NoCopy
|
||||
|
||||
name string
|
||||
url *nettypes.URL
|
||||
weight Weight
|
||||
name string
|
||||
url *nettypes.URL
|
||||
weight int
|
||||
|
||||
http.Handler `json:"-"`
|
||||
health.HealthMonitor
|
||||
}
|
||||
http.Handler `json:"-"`
|
||||
types.HealthMonitor
|
||||
}
|
||||
|
||||
Server interface {
|
||||
http.Handler
|
||||
health.HealthMonitor
|
||||
Name() string
|
||||
Key() string
|
||||
URL() *nettypes.URL
|
||||
Weight() Weight
|
||||
SetWeight(weight Weight)
|
||||
TryWake() error
|
||||
}
|
||||
)
|
||||
|
||||
func NewServer(name string, url *nettypes.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
|
||||
func NewServer(name string, url *nettypes.URL, weight int, handler http.Handler, healthMon types.HealthMonitor) types.LoadBalancerServer {
|
||||
srv := &server{
|
||||
name: name,
|
||||
url: url,
|
||||
@@ -44,9 +31,9 @@ func NewServer(name string, url *nettypes.URL, weight Weight, handler http.Handl
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
||||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) types.LoadBalancerServer {
|
||||
srv := &server{
|
||||
weight: Weight(weight),
|
||||
weight: int(weight),
|
||||
url: nettypes.MustParseURL("http://localhost"),
|
||||
}
|
||||
return srv
|
||||
@@ -64,11 +51,11 @@ func (srv *server) Key() string {
|
||||
return srv.url.Host
|
||||
}
|
||||
|
||||
func (srv *server) Weight() Weight {
|
||||
func (srv *server) Weight() int {
|
||||
return srv.weight
|
||||
}
|
||||
|
||||
func (srv *server) SetWeight(weight Weight) {
|
||||
func (srv *server) SetWeight(weight int) {
|
||||
srv.weight = weight
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Server = types.Server
|
||||
Servers = []types.Server
|
||||
Weight = types.Weight
|
||||
Config = types.Config
|
||||
Mode = types.Mode
|
||||
)
|
||||
@@ -1,8 +0,0 @@
|
||||
package types
|
||||
|
||||
type Config struct {
|
||||
Link string `json:"link"`
|
||||
Mode Mode `json:"mode"`
|
||||
Weight Weight `json:"weight"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeUnset Mode = ""
|
||||
ModeRoundRobin Mode = "roundrobin"
|
||||
ModeLeastConn Mode = "leastconn"
|
||||
ModeIPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(ModeRoundRobin):
|
||||
*mode = ModeRoundRobin
|
||||
return true
|
||||
case string(ModeLeastConn):
|
||||
*mode = ModeLeastConn
|
||||
return true
|
||||
case string(ModeIPHash):
|
||||
*mode = ModeIPHash
|
||||
return true
|
||||
}
|
||||
*mode = ModeRoundRobin
|
||||
return false
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package types
|
||||
|
||||
type Weight int
|
||||
252
internal/net/gphttp/websocket/manager.go
Normal file
252
internal/net/gphttp/websocket/manager.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
// Manager handles WebSocket connection state and ping-pong
|
||||
type Manager struct {
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
pongWriteTimeout time.Duration
|
||||
pingCheckTicker *time.Ticker
|
||||
lastPingTime atomic.Value
|
||||
readCh chan []byte
|
||||
err error
|
||||
}
|
||||
|
||||
var defaultUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
// TODO: add CORS
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
ErrReadTimeout = errors.New("read timeout")
|
||||
ErrWriteTimeout = errors.New("write timeout")
|
||||
)
|
||||
|
||||
const (
|
||||
TextMessage = websocket.TextMessage
|
||||
BinaryMessage = websocket.BinaryMessage
|
||||
)
|
||||
|
||||
// NewManagerWithUpgrade upgrades the HTTP connection to a WebSocket connection and returns a Manager.
|
||||
// If the upgrade fails, the error is returned.
|
||||
// If the upgrade succeeds, the Manager is returned.
|
||||
func NewManagerWithUpgrade(c *gin.Context, upgrader ...websocket.Upgrader) (*Manager, error) {
|
||||
var actualUpgrader websocket.Upgrader
|
||||
if len(upgrader) == 0 {
|
||||
actualUpgrader = defaultUpgrader
|
||||
} else {
|
||||
actualUpgrader = upgrader[0]
|
||||
}
|
||||
|
||||
conn, err := actualUpgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
cm := &Manager{
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
pongWriteTimeout: 2 * time.Second,
|
||||
pingCheckTicker: time.NewTicker(3 * time.Second),
|
||||
readCh: make(chan []byte, 1),
|
||||
}
|
||||
cm.lastPingTime.Store(time.Now())
|
||||
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
if common.IsDebug {
|
||||
cm.err = fmt.Errorf("connection closed: code=%d, text=%s", code, text)
|
||||
}
|
||||
cm.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
go cm.pingCheckRoutine()
|
||||
go cm.readRoutine()
|
||||
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
// Periodic writes data to the connection periodically.
|
||||
// If the connection is closed, the error is returned.
|
||||
// If the write timeout is reached, ErrWriteTimeout is returned.
|
||||
func (cm *Manager) PeriodicWrite(interval time.Duration, getData func() (any, error)) error {
|
||||
write := func() {
|
||||
data, err := getData()
|
||||
if err != nil {
|
||||
cm.err = err
|
||||
cm.Close()
|
||||
return
|
||||
}
|
||||
if err := cm.WriteJSON(data, interval); err != nil {
|
||||
cm.err = err
|
||||
cm.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// initial write before the ticker starts
|
||||
write()
|
||||
if cm.err != nil {
|
||||
return cm.err
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return cm.err
|
||||
case <-ticker.C:
|
||||
write()
|
||||
if cm.err != nil {
|
||||
return cm.err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteJSON writes a JSON message to the connection with json.
|
||||
// If the connection is closed, the error is returned.
|
||||
// If the write timeout is reached, ErrWriteTimeout is returned.
|
||||
func (cm *Manager) WriteJSON(data any, timeout time.Duration) error {
|
||||
bytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cm.WriteData(websocket.TextMessage, bytes, timeout)
|
||||
}
|
||||
|
||||
// WriteData writes a message to the connection with sonic.
|
||||
// If the connection is closed, the error is returned.
|
||||
// If the write timeout is reached, ErrWriteTimeout is returned.
|
||||
func (cm *Manager) WriteData(typ int, data []byte, timeout time.Duration) error {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return cm.err
|
||||
default:
|
||||
if err := cm.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
err := cm.conn.WriteMessage(typ, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, websocket.ErrCloseSent) {
|
||||
return cm.err
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ErrWriteTimeout
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadJSON reads a JSON message from the connection and unmarshals it into the provided struct with sonic
|
||||
// If the connection is closed, the error is returned.
|
||||
// If the message fails to unmarshal, the error is returned.
|
||||
// If the read timeout is reached, ErrReadTimeout is returned.
|
||||
func (cm *Manager) ReadJSON(out any, timeout time.Duration) error {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return cm.err
|
||||
case data := <-cm.readCh:
|
||||
return json.Unmarshal(data, out)
|
||||
case <-time.After(timeout):
|
||||
return ErrReadTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection and cancels the context
|
||||
func (cm *Manager) Close() {
|
||||
cm.cancel()
|
||||
cm.pingCheckTicker.Stop()
|
||||
cm.conn.Close()
|
||||
}
|
||||
|
||||
func (cm *Manager) GracefulClose() {
|
||||
_ = cm.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
_ = cm.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
cm.Close()
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the context is done or the connection is closed
|
||||
func (cm *Manager) Done() <-chan struct{} {
|
||||
return cm.ctx.Done()
|
||||
}
|
||||
|
||||
func (cm *Manager) pingCheckRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return
|
||||
case <-cm.pingCheckTicker.C:
|
||||
if time.Since(cm.lastPingTime.Load().(time.Time)) > 5*time.Second {
|
||||
if common.IsDebug {
|
||||
cm.err = errors.New("no ping received in 5 seconds, closing connection")
|
||||
}
|
||||
cm.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *Manager) readRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return
|
||||
default:
|
||||
typ, data, err := cm.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if cm.ctx.Err() == nil { // connection is not closed
|
||||
cm.err = fmt.Errorf("failed to read message: %w", err)
|
||||
cm.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if typ == websocket.TextMessage && string(data) == "ping" {
|
||||
cm.lastPingTime.Store(time.Now())
|
||||
if err := cm.conn.SetWriteDeadline(time.Now().Add(cm.pongWriteTimeout)); err != nil {
|
||||
cm.err = fmt.Errorf("failed to set write deadline: %w", err)
|
||||
cm.Close()
|
||||
return
|
||||
}
|
||||
if err := cm.conn.WriteMessage(websocket.TextMessage, []byte("pong")); err != nil {
|
||||
cm.err = fmt.Errorf("failed to write pong message: %w", err)
|
||||
cm.Close()
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if typ == websocket.TextMessage || typ == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return
|
||||
case cm.readCh <- data:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package gpwebsocket
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
apitypes "github.com/yusing/go-proxy/internal/api/types"
|
||||
)
|
||||
|
||||
func warnNoMatchDomains() {
|
||||
@@ -30,8 +32,6 @@ func SetWebsocketAllowedDomains(h http.Header, domains []string) {
|
||||
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
|
||||
}
|
||||
|
||||
const writeTimeout = time.Second * 10
|
||||
|
||||
// Initiate upgrades the HTTP connection to a WebSocket connection.
|
||||
// It returns a WebSocket connection and an error if the upgrade fails.
|
||||
// It logs and responds with an error if the upgrade fails.
|
||||
@@ -76,39 +76,16 @@ func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
return upgrader.Upgrade(w, r, nil)
|
||||
}
|
||||
|
||||
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||
conn, err := Initiate(w, r)
|
||||
func PeriodicWrite(c *gin.Context, interval time.Duration, get func() (any, error)) {
|
||||
manager, err := NewManagerWithUpgrade(c)
|
||||
if err != nil {
|
||||
c.Error(apitypes.InternalServerError(err, "failed to upgrade to websocket"))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := do(conn); err != nil {
|
||||
return
|
||||
err = manager.PeriodicWrite(interval, get)
|
||||
if err != nil {
|
||||
c.Error(apitypes.InternalServerError(err, "failed to write to websocket"))
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := do(conn); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText writes a text message to the websocket connection.
|
||||
// It returns true if the message was written successfully, false otherwise.
|
||||
// It logs an error if the message is not written successfully.
|
||||
func WriteText(conn *websocket.Conn, msg string) error {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
return conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
}
|
||||
|
||||
func errHandler(w http.ResponseWriter, r *http.Request, status int, reason error) {
|
||||
26
internal/net/gphttp/websocket/writer.go
Normal file
26
internal/net/gphttp/websocket/writer.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
msgType int
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func (cm *Manager) NewWriter(msgType int) io.Writer {
|
||||
return &Writer{
|
||||
msgType: msgType,
|
||||
manager: cm,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (int, error) {
|
||||
return len(p), w.manager.WriteData(w.msgType, p, 10*time.Second)
|
||||
}
|
||||
|
||||
func (w *Writer) Close() error {
|
||||
return w.manager.conn.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user