refactor: code refactor and improved context and error handling

This commit is contained in:
yusing
2025-05-24 10:02:24 +08:00
parent 1f1ae38e4d
commit 5b7c392297
31 changed files with 116 additions and 98 deletions

View File

@@ -1,3 +1,3 @@
package types
type Weight uint16
type Weight int

View File

@@ -8,10 +8,10 @@ import (
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return log.WithLevel(level).
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
return log.WithLevel(level). //nolint:zerologlint
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }

View File

@@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
for _, cidr := range wl.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
@@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
}
}
if !allow {

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
@@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"errors"
"fmt"
"io"
@@ -103,7 +104,15 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req) //nolint:gosec
if err != nil {
return err
}

View File

@@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
@@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
res, err := transport.RoundTrip(outreq)
@@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
//nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
//nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */
//nolint:errcheck
bdp.Start()
}

View File

@@ -16,7 +16,7 @@ import (
)
type CertProvider interface {
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type Server struct {