mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-21 16:01:22 +02:00
tweak: replace coder/websocket with gorilla/websocket
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
package gpwebsocket
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
@@ -27,29 +30,41 @@ func SetWebsocketAllowedDomains(h http.Header, domains []string) {
|
||||
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
|
||||
}
|
||||
|
||||
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
var originPats []string
|
||||
var localAddresses = []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
const writeTimeout = time.Second * 10
|
||||
|
||||
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
|
||||
allowedDomains := WebsocketAllowedDomains(r.Header)
|
||||
if len(allowedDomains) == 0 {
|
||||
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
originPats = make([]string, len(allowedDomains))
|
||||
for i, domain := range allowedDomains {
|
||||
if domain[0] != '.' {
|
||||
originPats[i] = "*." + domain
|
||||
} else {
|
||||
originPats[i] = "*" + domain
|
||||
}
|
||||
upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
if slices.Contains(localAddresses, host) {
|
||||
return true
|
||||
}
|
||||
for _, domain := range allowedDomains {
|
||||
if domain[0] == '.' {
|
||||
if host == domain[1:] || strings.HasSuffix(host, domain) {
|
||||
return true
|
||||
}
|
||||
} else if host == domain || strings.HasSuffix(host, "."+domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
originPats = append(originPats, localAddresses...)
|
||||
}
|
||||
return websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
OriginPatterns: originPats,
|
||||
})
|
||||
return upgrader.Upgrade(w, r, nil)
|
||||
}
|
||||
|
||||
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||
@@ -58,8 +73,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
//nolint:errcheck
|
||||
defer conn.CloseNow()
|
||||
defer conn.Close()
|
||||
|
||||
if err := do(conn); err != nil {
|
||||
return
|
||||
@@ -73,6 +87,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := do(conn); err != nil {
|
||||
return
|
||||
}
|
||||
@@ -83,10 +98,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
|
||||
// WriteText writes a text message to the websocket connection.
|
||||
// It returns true if the message was written successfully, false otherwise.
|
||||
// It logs an error if the message is not written successfully.
|
||||
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
|
||||
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
|
||||
logging.Err(err).Msg("failed to write text message")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
func WriteText(conn *websocket.Conn, msg string) error {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
return conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
}
|
||||
|
||||
@@ -3,16 +3,16 @@ package gpwebsocket
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
conn *websocket.Conn
|
||||
msgType websocket.MessageType
|
||||
msgType int
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) *Writer {
|
||||
func NewWriter(ctx context.Context, conn *websocket.Conn, msgType int) *Writer {
|
||||
return &Writer{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
@@ -21,9 +21,10 @@ func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.Mess
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (int, error) {
|
||||
return len(p), w.conn.Write(w.ctx, w.msgType, p)
|
||||
}
|
||||
|
||||
func (w *Writer) Close() error {
|
||||
return w.conn.CloseNow()
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return 0, w.ctx.Err()
|
||||
default:
|
||||
return len(p), w.conn.WriteMessage(w.msgType, p)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user