Compare commits

..

8 Commits

7 changed files with 68 additions and 25 deletions

View File

@@ -84,10 +84,10 @@ jobs:
outputs: type=image,name=${{ env.REGISTRY }}/${{ inputs.image_name }},push-by-digest=true,name-canonical=true,push=true
cache-from: |
type=registry,ref=${{ env.REGISTRY }}/${{ inputs.image_name }}:buildcache-${{ env.PLATFORM_PAIR }}
type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }}
# type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }}
cache-to: |
type=registry,ref=${{ env.REGISTRY }}/${{ inputs.image_name }}:buildcache-${{ env.PLATFORM_PAIR }},mode=max
type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }},mode=max
# type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }},mode=max
build-args: |
VERSION=${{ github.ref_name }}
MAKE_ARGS=${{ env.MAKE_ARGS }}

View File

@@ -19,7 +19,6 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()

View File

@@ -73,7 +73,6 @@ func (m *memLogger) Write(p []byte) (n int, err error) {
func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

View File

@@ -3,13 +3,13 @@ package gpwebsocket
import (
"net"
"net/http"
"slices"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
)
func warnNoMatchDomains() {
@@ -30,12 +30,17 @@ func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
var localAddresses = []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
const writeTimeout = time.Second * 10
// Initiate upgrades the HTTP connection to a WebSocket connection.
// It returns a WebSocket connection and an error if the upgrade fails.
// It logs and responds with an error if the upgrade fails.
//
// No further http.Error should be called after this function.
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
upgrader := websocket.Upgrader{}
upgrader := websocket.Upgrader{
Error: errHandler,
}
allowedDomains := WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 {
@@ -49,9 +54,13 @@ func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
if err != nil {
host = r.Host
}
if slices.Contains(localAddresses, host) {
if host == "localhost" || host == agent.AgentHost {
return true
}
ip := net.ParseIP(host)
if ip != nil {
return ip.IsLoopback() || ip.IsPrivate()
}
for _, domain := range allowedDomains {
if domain[0] == '.' {
if host == domain[1:] || strings.HasSuffix(host, domain) {
@@ -70,7 +79,6 @@ func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := Initiate(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
@@ -102,3 +110,15 @@ func WriteText(conn *websocket.Conn, msg string) error {
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
return conn.WriteMessage(websocket.TextMessage, []byte(msg))
}
func errHandler(w http.ResponseWriter, r *http.Request, status int, reason error) {
log.Error().
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("url", r.URL.String()).
Int("status", status).
AnErr("reason", reason).
Msg("websocket error")
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
@@ -198,7 +197,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
dstV.Set(reflect.Zero(dstT))
return nil
}
return gperr.Errorf("deserialize: src is %w and dst is not settable\n%s", ErrNilValue, debug.Stack())
return gperr.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
}
if dstT.Implements(mapUnmarshalerType) {

View File

@@ -106,13 +106,38 @@ func (p BidirectionalPipe) Start() error {
return errors.Join(srcErr, dstErr)
}
type httpFlusher interface {
Flush() error
type flushErrorInterface interface {
FlushError() error
}
func getHTTPFlusher(dst io.Writer) httpFlusher {
type flusherWrapper struct {
rw http.Flusher
}
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
func (f *flusherWrapper) FlushError() error {
f.rw.Flush()
return nil
}
func getHTTPFlusher(dst io.Writer) flushErrorInterface {
// pre-unwrap the flusher to prevent unwrap and check in every loop
if rw, ok := dst.(http.ResponseWriter); ok {
return http.NewResponseController(rw)
for {
switch t := rw.(type) {
case flushErrorInterface:
return t
case http.Flusher:
return &flusherWrapper{rw: t}
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil
}
}
}
return nil
}
@@ -158,7 +183,6 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
}()
}
flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil
for {
nr, er := src.Reader.Read(buf)
if nr > 0 {
@@ -177,15 +201,10 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
err = io.ErrShortWrite
return
}
if canFlush {
err = flusher.Flush()
if flusher != nil {
err = flusher.FlushError()
if err != nil {
if errors.Is(err, http.ErrNotSupported) {
canFlush = false
err = nil
} else {
return err
}
return err
}
}
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"net/url"
"time"
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/watcher/health"
@@ -57,6 +58,7 @@ func NewAgentProxiedMonitor(agent *agentPkg.AgentConfig, config *health.HealthCh
}
func (mon *AgentProxiedMonitor) CheckHealth() (result *health.HealthCheckResult, err error) {
startTime := time.Now()
result = new(health.HealthCheckResult)
ctx, cancel := mon.ContextWithTimeout("timeout querying agent")
defer cancel()
@@ -64,11 +66,16 @@ func (mon *AgentProxiedMonitor) CheckHealth() (result *health.HealthCheckResult,
if err != nil {
return result, err
}
endTime := time.Now()
switch status {
case http.StatusOK:
err = json.Unmarshal(data, result)
default:
err = errors.New(string(data))
}
if err == nil && result.Latency != 0 {
// use godoxy to agent latency
result.Latency = endTime.Sub(startTime)
}
return
}