tweak: replace coder/websocket with gorilla/websocket

This commit is contained in:
yusing
2025-05-19 23:15:11 +08:00
parent cee6eaecff
commit 1f50ee7f2f
16 changed files with 96 additions and 83 deletions

View File

@@ -4,8 +4,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
@@ -15,8 +14,7 @@ import (
func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
gpwebsocket.Periodic(w, r, 10*time.Second, func(conn *websocket.Conn) error {
wsjson.Write(r.Context(), conn, cfg.ListAgents())
return nil
return conn.WriteJSON(cfg.ListAgents())
})
} else {
gphttp.RespondJSON(w, r, cfg.ListAgents())

View File

@@ -22,8 +22,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
defer conn.Close()
logs, cancel := memlogger.Events()
defer cancel()
@@ -35,7 +34,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
err = autocert.ObtainCert()
if err != nil {
gperr.LogError("failed to obtain cert", err)
gpwebsocket.WriteText(r, conn, err.Error())
_ = gpwebsocket.WriteText(conn, err.Error())
} else {
logging.Info().Msg("cert obtained successfully")
}
@@ -46,7 +45,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
if !gpwebsocket.WriteText(r, conn, string(l)) {
if err := gpwebsocket.WriteText(conn, string(l)); err != nil {
return
}
case <-done:

View File

@@ -1,15 +1,18 @@
package dockerapi
import (
"context"
"errors"
"net/http"
"strconv"
"github.com/coder/websocket"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/pkg/stdcopy"
"github.com/gorilla/websocket"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/task"
)
func Logs(w http.ResponseWriter, r *http.Request) {
@@ -31,6 +34,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
gphttp.NotFound(w, "server not found")
return
}
defer dockerClient.Close()
opts := container.LogsOptions{
ShowStdout: stdout,
@@ -56,11 +60,14 @@ func Logs(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
defer conn.CloseNow()
defer conn.Close()
writer := gpwebsocket.NewWriter(r.Context(), conn, websocket.MessageText)
writer := gpwebsocket.NewWriter(r.Context(), conn, websocket.TextMessage)
_, err = stdcopy.StdCopy(writer, writer, logs) // de-multiplex logs
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
return
}
logging.Err(err).
Str("server", server).
Str("container", containerID).

View File

@@ -6,8 +6,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
@@ -65,10 +64,12 @@ func getDockerClient(server string) (*docker.SharedClient, bool, error) {
break
}
}
for _, agent := range cfg.ListAgents() {
if agent.Name() == server {
host = agent.FakeDockerHost()
break
if host == "" {
for _, agent := range cfg.ListAgents() {
if agent.Name() == server {
host = agent.FakeDockerHost()
break
}
}
}
if host == "" {
@@ -115,7 +116,7 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g
if err != nil {
return err
}
return wsjson.Write(r.Context(), conn, result)
return conn.WriteJSON(result)
})
} else {
result, err := getResult(r.Context(), dockerClients)

View File

@@ -4,8 +4,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
@@ -15,7 +14,7 @@ import (
func Health(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, routes.HealthMap())
return conn.WriteJSON(routes.HealthMap())
})
} else {
gphttp.RespondJSON(w, r, routes.HealthMap())

View File

@@ -4,8 +4,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
@@ -15,7 +14,7 @@ import (
func ListRouteProvidersHandler(cfgInstance config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
gpwebsocket.Periodic(w, r, 5*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, cfgInstance.RouteProviderList())
return conn.WriteJSON(cfgInstance.RouteProviderList())
})
} else {
gphttp.RespondJSON(w, r, cfgInstance.RouteProviderList())

View File

@@ -4,8 +4,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
@@ -16,7 +15,7 @@ import (
func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, getStats(cfg))
return conn.WriteJSON(getStats(cfg))
})
} else {
gphttp.RespondJSON(w, r, getStats(cfg))

View File

@@ -8,7 +8,7 @@ import (
"sync"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
)
@@ -81,7 +81,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.connChans.Store(logCh, struct{}{})
defer func() {
_ = conn.CloseNow()
_ = conn.Close()
m.notifyLock.Lock()
m.connChans.Delete(logCh)
@@ -89,7 +89,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.notifyLock.Unlock()
}()
if err := m.wsInitial(r.Context(), conn); err != nil {
if err := m.wsInitial(conn); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -169,15 +169,16 @@ func (m *memLogger) events() (logs <-chan []byte, cancel func()) {
}
}
func (m *memLogger) writeBytes(ctx context.Context, conn *websocket.Conn, b []byte) error {
return conn.Write(ctx, websocket.MessageText, b)
func (m *memLogger) writeBytes(conn *websocket.Conn, b []byte) error {
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return conn.WriteMessage(websocket.TextMessage, b)
}
func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error {
func (m *memLogger) wsInitial(conn *websocket.Conn) error {
m.Lock()
defer m.Unlock()
return m.writeBytes(ctx, conn, m.Bytes())
return m.writeBytes(conn, m.Bytes())
}
func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {
@@ -188,7 +189,7 @@ func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-
case logRange := <-ch:
m.RLock()
msg := m.Bytes()[logRange.Start:logRange.End]
err := m.writeBytes(ctx, conn, msg)
err := m.writeBytes(conn, msg)
m.RUnlock()
if err != nil {
return

View File

@@ -5,8 +5,7 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gorilla/websocket"
metricsutils "github.com/yusing/go-proxy/internal/metrics/utils"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
@@ -45,7 +44,7 @@ func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request
if data == nil {
return nil
}
return wsjson.Write(r.Context(), conn, data)
return conn.WriteJSON(data)
})
} else {
data, err := p.getRespData(r)

View File

@@ -1,11 +1,14 @@
package gpwebsocket
import (
"net"
"net/http"
"slices"
"strings"
"sync"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"github.com/yusing/go-proxy/internal/logging"
)
@@ -27,29 +30,41 @@ func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
var localAddresses = []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
const writeTimeout = time.Second * 10
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
upgrader := websocket.Upgrader{}
allowedDomains := WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
} else {
originPats = make([]string, len(allowedDomains))
for i, domain := range allowedDomains {
if domain[0] != '.' {
originPats[i] = "*." + domain
} else {
originPats[i] = "*" + domain
}
upgrader.CheckOrigin = func(r *http.Request) bool {
return true
}
} else {
upgrader.CheckOrigin = func(r *http.Request) bool {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
if slices.Contains(localAddresses, host) {
return true
}
for _, domain := range allowedDomains {
if domain[0] == '.' {
if host == domain[1:] || strings.HasSuffix(host, domain) {
return true
}
} else if host == domain || strings.HasSuffix(host, "."+domain) {
return true
}
}
return false
}
originPats = append(originPats, localAddresses...)
}
return websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: originPats,
})
return upgrader.Upgrade(w, r, nil)
}
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
@@ -58,8 +73,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
defer conn.Close()
if err := do(conn); err != nil {
return
@@ -73,6 +87,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
case <-r.Context().Done():
return
case <-ticker.C:
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if err := do(conn); err != nil {
return
}
@@ -83,10 +98,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
// WriteText writes a text message to the websocket connection.
// It returns true if the message was written successfully, false otherwise.
// It logs an error if the message is not written successfully.
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
logging.Err(err).Msg("failed to write text message")
return false
}
return true
func WriteText(conn *websocket.Conn, msg string) error {
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
return conn.WriteMessage(websocket.TextMessage, []byte(msg))
}

View File

@@ -3,16 +3,16 @@ package gpwebsocket
import (
"context"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
)
type Writer struct {
conn *websocket.Conn
msgType websocket.MessageType
msgType int
ctx context.Context
}
func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) *Writer {
func NewWriter(ctx context.Context, conn *websocket.Conn, msgType int) *Writer {
return &Writer{
ctx: ctx,
conn: conn,
@@ -21,9 +21,10 @@ func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.Mess
}
func (w *Writer) Write(p []byte) (int, error) {
return len(p), w.conn.Write(w.ctx, w.msgType, p)
}
func (w *Writer) Close() error {
return w.conn.CloseNow()
select {
case <-w.ctx.Done():
return 0, w.ctx.Err()
default:
return len(p), w.conn.WriteMessage(w.msgType, p)
}
}