migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher

This commit is contained in:
yusing
2024-10-29 11:34:58 +08:00
parent cfa74d69ae
commit e5bbb18414
137 changed files with 2640 additions and 2348 deletions

View File

@@ -6,7 +6,6 @@ import (
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
@@ -38,7 +37,6 @@ func NewHandler() http.Handler {
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux
}
@@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
Logger.Warnf("blocked API request from %s", host)
LogWarn(r).Msgf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return
}

View File

@@ -1,90 +0,0 @@
package errorpage
import (
"context"
"fmt"
"os"
"path"
"sync"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var (
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
loadContent()
go watchDir()
})
func GetStaticFile(filename string) ([]byte, bool) {
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
return
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
Logger.Error(err)
return
}
for _, file := range files {
if fileContentMap.Has(file) {
continue
}
content, err := os.ReadFile(file)
if err != nil {
Logger.Errorf("failed to read error page resource %s: %s", file, err)
continue
}
file = path.Base(file)
Logger.Infof("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(context.Background())
for {
select {
case event, ok := <-eventCh:
if !ok {
return
}
filename := event.ActorName
switch event.Action {
case events.ActionFileWritten:
fileContentMap.Delete(filename)
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
Logger.Infof("error page resource %s deleted", filename)
case events.ActionFileRenamed:
Logger.Infof("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
Logger.Errorf("error watching error page directory: %s", err)
}
}
}

View File

@@ -1,31 +0,0 @@
package errorpage
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func GetHandleFunc() http.HandlerFunc {
setup()
return serveHTTP
}
func serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
content, ok := fileContentMap.Load(r.URL.Path)
if !ok {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
if _, err := w.Write(content); err != nil {
HandleErr(w, r, err)
}
}

View File

@@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return
}
var validateErr E.Error
var valErr E.Error
if filename == common.ConfigFileName {
validateErr = config.Validate(content)
valErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
validateErr = provider.Validate(content)
valErr = provider.Validate(content)
}
// no validation for include files
if validateErr != nil {
U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest)
if valErr != nil {
U.RespondJSON(w, r, valErr, http.StatusBadRequest)
return
}

View File

@@ -20,16 +20,13 @@ func ReloadServer() E.Error {
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
b, err := io.ReadAll(resp.Body)
failure := E.Errorf("server reload status %v", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
return failure.Extraf("unable to read response body: %s", err)
return failure.With(err)
}
reloadErr, ok := E.FromJSON(b)
if ok {
return E.Join("reload success, but server returned error", reloadErr)
}
return failure.Extraf("unable to read response body")
reloadErr := string(body)
return failure.Withf(reloadErr)
}
return nil
}
@@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) {
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
return
}
var res T

View File

@@ -9,8 +9,8 @@ import (
func Reload(w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil {
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
U.HandleErr(w, r, err)
return
}
U.WriteBody(w, []byte("OK"))
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func Stats(w http.ResponseWriter, r *http.Request) {
@@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
originPats = []string{"*"}
} else {
for i, domain := range config.Value().MatchDomains {
@@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
OriginPatterns: originPats,
})
if err != nil {
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
return
}
/* trunk-ignore(golangci-lint/errcheck) */
@@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
for range ticker.C {
stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
U.LogError(r).Msg("failed to write JSON")
return
}
}
@@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
func getStats() map[string]any {
return map[string]any{
"proxies": config.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
"uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
}
}

View File

@@ -1,37 +1,31 @@
package utils
import (
"errors"
"fmt"
"net/http"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
var Logger = logrus.WithField("module", "api")
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
if origErr == nil {
return
}
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
Logger.Error(err)
LogError(r).Msg(origErr.Error())
if len(code) > 0 {
http.Error(w, err.String(), code[0])
http.Error(w, origErr.Error(), code[0])
return
}
http.Error(w, err.String(), http.StatusInternalServerError)
http.Error(w, origErr.Error(), http.StatusInternalServerError)
}
func ErrMissingKey(k string) error {
return errors.New("missing key '" + k + "' in query or request body")
return E.New("missing key '" + k + "' in query or request body")
}
func ErrInvalidKey(k string) error {
return errors.New("invalid key '" + k + "' in query or request body")
return E.New("invalid key '" + k + "' in query or request body")
}
func ErrNotFound(k, v string) error {
return fmt.Errorf("key %q with value %q not found", k, v)
return E.Errorf("key %q with value %q not found", k, v)
}

View File

@@ -9,12 +9,11 @@ import (
)
var (
HTTPClient = &http.Client{
httpClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
@@ -23,7 +22,7 @@ var (
},
}
Get = HTTPClient.Get
Post = HTTPClient.Post
Head = HTTPClient.Head
Get = httpClient.Get
Post = httpClient.Post
Head = httpClient.Head
)

View File

@@ -0,0 +1,18 @@
package utils
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).Str("module", "api").
Str("method", r.Method).
Str("path", r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }

View File

@@ -3,6 +3,8 @@ package utils
import (
"encoding/json"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
)
func WriteBody(w http.ResponseWriter, body []byte) {
@@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) {
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool {
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ")
if err != nil {
HandleErr(w, r, err)
return false
var j []byte
var err error
switch data := data.(type) {
case string:
j = []byte(`"` + data + `"`)
case []byte:
j = data
default:
j, err = json.MarshalIndent(data, "", " ")
if err != nil {
logging.Panic().Err(err).Msg("failed to marshal json")
return false
}
}
_, err = w.Write(j)
if err != nil {