refactor: move api/v1/utils to net/gphttp

This commit is contained in:
yusing
2025-03-28 06:48:30 +08:00
parent d315710310
commit dfd2f3962c
8 changed files with 227 additions and 147 deletions

View File

@@ -0,0 +1,46 @@
package gphttp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
)
func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil {
switch {
case errors.Is(err, http.ErrHandlerTimeout),
errors.Is(err, context.DeadlineExceeded):
logging.Err(err).Msg("timeout writing body")
default:
logging.Err(err).Msg("failed to write body")
}
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
var err error
switch data := data.(type) {
case string:
_, err = w.Write([]byte(fmt.Sprintf("%q", data)))
case []byte:
panic("use WriteBody instead")
default:
err = json.NewEncoder(w).Encode(data)
}
if err != nil {
LogError(r).Err(err).Msg("failed to encode json")
return false
}
return true
}

View File

@@ -0,0 +1,27 @@
package gphttp
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
httpClient = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 60 * time.Second, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
Get = httpClient.Get
Post = httpClient.Post
Head = httpClient.Head
)

View File

@@ -0,0 +1,100 @@
package gphttp
import (
"context"
"encoding/json"
"errors"
"net/http"
"syscall"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
// ServerError is for handling server errors.
//
// It logs the error and returns http.StatusInternalServerError to the client.
// Status code can be specified as an argument.
func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) {
switch {
case err == nil,
errors.Is(err, context.Canceled),
errors.Is(err, syscall.EPIPE),
errors.Is(err, syscall.ECONNRESET):
return
}
LogError(r).Msg(err.Error())
if httpheaders.IsWebsocket(r.Header) {
return
}
if len(code) == 0 {
code = []int{http.StatusInternalServerError}
}
http.Error(w, http.StatusText(code[0]), code[0])
}
// ClientError is for responding to client errors.
//
// It returns http.StatusBadRequest with reason to the client.
// Status code can be specified as an argument.
//
// For JSON marshallable errors (e.g. gperr.Error), it returns the error details as JSON.
// Otherwise, it returns the error details as plain text.
func ClientError(w http.ResponseWriter, err error, code ...int) {
if len(code) == 0 {
code = []int{http.StatusBadRequest}
}
if gperr.IsJSONMarshallable(err) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(err)
} else {
http.Error(w, err.Error(), code[0])
}
}
// JSONError returns a JSON response of gperr.Error with the given status code.
func JSONError(w http.ResponseWriter, err gperr.Error, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(err)
}
// BadRequest returns a Bad Request response with the given error message.
func BadRequest(w http.ResponseWriter, err string, code ...int) {
if len(code) == 0 {
code = []int{http.StatusBadRequest}
}
w.WriteHeader(code[0])
w.Write([]byte(err))
}
// Unauthorized returns an Unauthorized response with the given error message.
func Unauthorized(w http.ResponseWriter, err string) {
BadRequest(w, err, http.StatusUnauthorized)
}
// Forbidden returns a Forbidden response with the given error message.
func Forbidden(w http.ResponseWriter, err string) {
BadRequest(w, err, http.StatusForbidden)
}
// NotFound returns a Not Found response with the given error message.
func NotFound(w http.ResponseWriter, err string) {
BadRequest(w, err, http.StatusNotFound)
}
func ErrMissingKey(k string) error {
return gperr.New(k + " is required")
}
func ErrInvalidKey(k string) error {
return gperr.New(k + " is invalid")
}
func ErrAlreadyExists(k, v string) error {
return gperr.Errorf("%s %q already exists", k, v)
}
func ErrNotFound(k, v string) error {
return gperr.Errorf("%s %q not found", k, v)
}

View File

@@ -0,0 +1,86 @@
package gpwebsocket
import (
"net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
func warnNoMatchDomains() {
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
}
var warnNoMatchDomainOnce sync.Once
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
} else {
originPats = make([]string, len(allowedDomains))
for i, domain := range allowedDomains {
if domain[0] != '.' {
originPats[i] = "*." + domain
} else {
originPats[i] = "*" + domain
}
}
originPats = append(originPats, localAddresses...)
}
return websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: originPats,
})
}
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
return
}
//nolint:errcheck
defer conn.CloseNow()
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
return
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
return
case <-ticker.C:
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
return
}
}
}
}
// WriteText writes a text message to the websocket connection.
// It returns true if the message was written successfully, false otherwise.
// It logs an error if the message is not written successfully.
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
gperr.LogError("failed to write text message", err)
return false
}
return true
}

View File

@@ -0,0 +1,21 @@
package httpheaders
import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View File

@@ -0,0 +1,20 @@
package gphttp
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }
func LogDebug(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.DebugLevel) }