diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go deleted file mode 100644 index c7341e9c..00000000 --- a/internal/api/v1/utils/error.go +++ /dev/null @@ -1,55 +0,0 @@ -package utils - -import ( - "encoding/json" - "net/http" - - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/utils/strutils/ansi" -) - -// HandleErr logs the error and returns an error code to the client. -// If code is specified, it will be used as the HTTP status code; otherwise, -// http.StatusInternalServerError is used. -// -// The error is only logged but not returned to the client. -func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) { - if err == nil { - return - } - LogError(r).Msg(err.Error()) - if len(code) == 0 { - code = []int{http.StatusInternalServerError} - } - http.Error(w, http.StatusText(code[0]), code[0]) -} - -// RespondError returns error details to the client. -// If code is specified, it will be used as the HTTP status code; otherwise, -// http.StatusBadRequest is used. -func RespondError(w http.ResponseWriter, err error, code ...int) { - if len(code) == 0 { - code = []int{http.StatusBadRequest} - } - buf, err := json.Marshal(err) - if err != nil { // just in case - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - http.Error(w, ansi.StripANSI(err.Error()), code[0]) - return - } - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(code[0]) - _, _ = w.Write(buf) -} - -func ErrMissingKey(k string) error { - return E.New("missing key '" + k + "' in query or request body") -} - -func ErrInvalidKey(k string) error { - return E.New("invalid key '" + k + "' in query or request body") -} - -func ErrNotFound(k, v string) error { - return E.Errorf("key %q with value %q not found", k, v) -} diff --git a/internal/api/v1/utils/ws.go b/internal/api/v1/utils/ws.go deleted file mode 100644 index 127892d4..00000000 --- a/internal/api/v1/utils/ws.go +++ /dev/null @@ -1,68 +0,0 @@ -package utils - -import ( - "net/http" - "sync" - "time" - - "github.com/coder/websocket" - "github.com/yusing/go-proxy/internal/common" - config "github.com/yusing/go-proxy/internal/config/types" - "github.com/yusing/go-proxy/internal/logging" -) - -func warnNoMatchDomains() { - logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins") -} - -var warnNoMatchDomainOnce sync.Once - -func InitiateWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { - var originPats []string - - localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - - if len(cfg.Value().MatchDomains) == 0 { - warnNoMatchDomainOnce.Do(warnNoMatchDomains) - originPats = []string{"*"} - } else { - originPats = make([]string, len(cfg.Value().MatchDomains)) - for i, domain := range cfg.Value().MatchDomains { - originPats[i] = "*" + domain - } - originPats = append(originPats, localAddresses...) - } - if common.IsDebug { - originPats = []string{"*"} - } - return websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: originPats, - }) -} - -func PeriodicWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { - conn, err := InitiateWS(cfg, w, r) - if err != nil { - HandleErr(w, r, err) - return - } - /* trunk-ignore(golangci-lint/errcheck) */ - defer conn.CloseNow() - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-cfg.Context().Done(): - return - case <-r.Context().Done(): - return - case <-ticker.C: - if err := do(conn); err != nil { - LogError(r).Msg(err.Error()) - return - } - } - } -} diff --git a/internal/api/v1/utils/utils.go b/internal/net/gphttp/body.go similarity index 55% rename from internal/api/v1/utils/utils.go rename to internal/net/gphttp/body.go index 8cdfcaa5..3a940232 100644 --- a/internal/api/v1/utils/utils.go +++ b/internal/net/gphttp/body.go @@ -1,17 +1,24 @@ -package utils +package gphttp import ( + "context" "encoding/json" + "errors" "fmt" "net/http" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) func WriteBody(w http.ResponseWriter, body []byte) { if _, err := w.Write(body); err != nil { - logging.Err(err).Msg("failed to write body") + switch { + case errors.Is(err, http.ErrHandlerTimeout), + errors.Is(err, context.DeadlineExceeded): + logging.Err(err).Msg("timeout writing body") + default: + logging.Err(err).Msg("failed to write body") + } } } @@ -20,28 +27,19 @@ func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) w.WriteHeader(code[0]) } w.Header().Set("Content-Type", "application/json") - var j []byte var err error switch data := data.(type) { case string: - j = []byte(fmt.Sprintf("%q", data)) + _, err = w.Write([]byte(fmt.Sprintf("%q", data))) case []byte: - j = data - case error: - j, err = json.Marshal(ansi.StripANSI(data.Error())) + panic("use WriteBody instead") default: - j, err = json.MarshalIndent(data, "", " ") + err = json.NewEncoder(w).Encode(data) } if err != nil { - logging.Panic().Err(err).Msg("failed to marshal json") - return false - } - - _, err = w.Write(j) - if err != nil { - HandleErr(w, r, err) + LogError(r).Err(err).Msg("failed to encode json") return false } return true diff --git a/internal/api/v1/utils/http_client.go b/internal/net/gphttp/default_client.go similarity index 65% rename from internal/api/v1/utils/http_client.go rename to internal/net/gphttp/default_client.go index 48d743be..dee455a1 100644 --- a/internal/api/v1/utils/http_client.go +++ b/internal/net/gphttp/default_client.go @@ -1,22 +1,21 @@ -package utils +package gphttp import ( "crypto/tls" "net" "net/http" - - "github.com/yusing/go-proxy/internal/common" + "time" ) var ( httpClient = &http.Client{ - Timeout: common.ConnectionTimeout, + Timeout: 5 * time.Second, Transport: &http.Transport{ DisableKeepAlives: true, ForceAttemptHTTP2: false, DialContext: (&net.Dialer{ - Timeout: common.DialTimeout, - KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives + Timeout: 3 * time.Second, + KeepAlive: 60 * time.Second, // this is different from DisableKeepAlives }).DialContext, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, diff --git a/internal/net/gphttp/error.go b/internal/net/gphttp/error.go new file mode 100644 index 00000000..f269e3fd --- /dev/null +++ b/internal/net/gphttp/error.go @@ -0,0 +1,100 @@ +package gphttp + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "syscall" + + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" +) + +// ServerError is for handling server errors. +// +// It logs the error and returns http.StatusInternalServerError to the client. +// Status code can be specified as an argument. +func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) { + switch { + case err == nil, + errors.Is(err, context.Canceled), + errors.Is(err, syscall.EPIPE), + errors.Is(err, syscall.ECONNRESET): + return + } + LogError(r).Msg(err.Error()) + if httpheaders.IsWebsocket(r.Header) { + return + } + if len(code) == 0 { + code = []int{http.StatusInternalServerError} + } + http.Error(w, http.StatusText(code[0]), code[0]) +} + +// ClientError is for responding to client errors. +// +// It returns http.StatusBadRequest with reason to the client. +// Status code can be specified as an argument. +// +// For JSON marshallable errors (e.g. gperr.Error), it returns the error details as JSON. +// Otherwise, it returns the error details as plain text. +func ClientError(w http.ResponseWriter, err error, code ...int) { + if len(code) == 0 { + code = []int{http.StatusBadRequest} + } + if gperr.IsJSONMarshallable(err) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(err) + } else { + http.Error(w, err.Error(), code[0]) + } +} + +// JSONError returns a JSON response of gperr.Error with the given status code. +func JSONError(w http.ResponseWriter, err gperr.Error, code int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(err) +} + +// BadRequest returns a Bad Request response with the given error message. +func BadRequest(w http.ResponseWriter, err string, code ...int) { + if len(code) == 0 { + code = []int{http.StatusBadRequest} + } + w.WriteHeader(code[0]) + w.Write([]byte(err)) +} + +// Unauthorized returns an Unauthorized response with the given error message. +func Unauthorized(w http.ResponseWriter, err string) { + BadRequest(w, err, http.StatusUnauthorized) +} + +// Forbidden returns a Forbidden response with the given error message. +func Forbidden(w http.ResponseWriter, err string) { + BadRequest(w, err, http.StatusForbidden) +} + +// NotFound returns a Not Found response with the given error message. +func NotFound(w http.ResponseWriter, err string) { + BadRequest(w, err, http.StatusNotFound) +} + +func ErrMissingKey(k string) error { + return gperr.New(k + " is required") +} + +func ErrInvalidKey(k string) error { + return gperr.New(k + " is invalid") +} + +func ErrAlreadyExists(k, v string) error { + return gperr.Errorf("%s %q already exists", k, v) +} + +func ErrNotFound(k, v string) error { + return gperr.Errorf("%s %q not found", k, v) +} diff --git a/internal/net/gphttp/gpwebsocket/utils.go b/internal/net/gphttp/gpwebsocket/utils.go new file mode 100644 index 00000000..cb67dd62 --- /dev/null +++ b/internal/net/gphttp/gpwebsocket/utils.go @@ -0,0 +1,86 @@ +package gpwebsocket + +import ( + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/net/gphttp" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" +) + +func warnNoMatchDomains() { + logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins") +} + +var warnNoMatchDomainOnce sync.Once + +func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { + var originPats []string + + localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} + + allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header) + if len(allowedDomains) == 0 || common.IsDebug { + warnNoMatchDomainOnce.Do(warnNoMatchDomains) + originPats = []string{"*"} + } else { + originPats = make([]string, len(allowedDomains)) + for i, domain := range allowedDomains { + if domain[0] != '.' { + originPats[i] = "*." + domain + } else { + originPats[i] = "*" + domain + } + } + originPats = append(originPats, localAddresses...) + } + return websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: originPats, + }) +} + +func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { + conn, err := Initiate(w, r) + if err != nil { + gphttp.ServerError(w, r, err) + return + } + //nolint:errcheck + defer conn.CloseNow() + + if err := do(conn); err != nil { + gphttp.ServerError(w, r, err) + return + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-r.Context().Done(): + return + case <-ticker.C: + if err := do(conn); err != nil { + gphttp.ServerError(w, r, err) + return + } + } + } +} + +// WriteText writes a text message to the websocket connection. +// It returns true if the message was written successfully, false otherwise. +// It logs an error if the message is not written successfully. +func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool { + if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil { + gperr.LogError("failed to write text message", err) + return false + } + return true +} diff --git a/internal/net/gphttp/httpheaders/websocket.go b/internal/net/gphttp/httpheaders/websocket.go new file mode 100644 index 00000000..755d3248 --- /dev/null +++ b/internal/net/gphttp/httpheaders/websocket.go @@ -0,0 +1,21 @@ +package httpheaders + +import ( + "net/http" +) + +const ( + HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains" +) + +func WebsocketAllowedDomains(h http.Header) []string { + return h[HeaderXGoDoxyWebsocketAllowedDomains] +} + +func SetWebsocketAllowedDomains(h http.Header, domains []string) { + h[HeaderXGoDoxyWebsocketAllowedDomains] = domains +} + +func IsWebsocket(h http.Header) bool { + return UpgradeType(h) == "websocket" +} diff --git a/internal/api/v1/utils/logging.go b/internal/net/gphttp/logging.go similarity index 94% rename from internal/api/v1/utils/logging.go rename to internal/net/gphttp/logging.go index 194735f5..cfb67f02 100644 --- a/internal/api/v1/utils/logging.go +++ b/internal/net/gphttp/logging.go @@ -1,4 +1,4 @@ -package utils +package gphttp import ( "net/http" @@ -9,7 +9,6 @@ import ( func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { return logging.WithLevel(level). - Str("module", "api"). Str("remote", r.RemoteAddr). Str("host", r.Host). Str("uri", r.Method+" "+r.RequestURI)