refactor: code refactor and improved context and error handling

This commit is contained in:
yusing
2025-05-24 10:02:24 +08:00
parent 1f1ae38e4d
commit 5b7c392297
31 changed files with 116 additions and 98 deletions

View File

@@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
for _, cidr := range wl.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
@@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
}
}
if !allow {

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
@@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"errors"
"fmt"
"io"
@@ -103,7 +104,15 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req) //nolint:gosec
if err != nil {
return err
}