mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-17 22:19:42 +02:00
refactor: remove net.URL and net.CIDR types, improved unmarshal handling
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
@@ -24,7 +23,7 @@ type (
|
||||
Key, Value string
|
||||
}
|
||||
Host string
|
||||
CIDR struct{ types.CIDR }
|
||||
CIDR struct{ net.IPNet }
|
||||
)
|
||||
|
||||
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@@ -155,8 +156,11 @@ func TestHeaderFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCIDRFilter(t *testing.T) {
|
||||
cidr := []*CIDR{
|
||||
strutils.MustParse[*CIDR]("192.168.10.0/24"),
|
||||
cidr := []*CIDR{{
|
||||
net.IPNet{
|
||||
IP: net.ParseIP("192.168.10.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
}},
|
||||
}
|
||||
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
|
||||
inCIDR := &http.Request{
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
|
||||
)
|
||||
|
||||
// TODO: stats of each server.
|
||||
@@ -240,14 +239,14 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
// MarshalMap implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) MarshalMap() map[string]any {
|
||||
extra := make(map[string]any)
|
||||
lb.pool.RangeAll(func(k string, v Server) {
|
||||
extra[v.Key()] = v
|
||||
})
|
||||
|
||||
return (&monitor.JSONRepresentation{
|
||||
return (&health.JSONRepresentation{
|
||||
Name: lb.Name(),
|
||||
Status: lb.Status(),
|
||||
Started: lb.startTime,
|
||||
@@ -256,7 +255,7 @@ func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
"config": lb.Config,
|
||||
"pool": extra,
|
||||
},
|
||||
}).MarshalJSON()
|
||||
}).MarshalMap()
|
||||
}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
|
||||
@@ -2,9 +2,9 @@ package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
@@ -15,7 +15,7 @@ type (
|
||||
_ U.NoCopy
|
||||
|
||||
name string
|
||||
url *net.URL
|
||||
url *url.URL
|
||||
weight Weight
|
||||
|
||||
http.Handler `json:"-"`
|
||||
@@ -27,7 +27,7 @@ type (
|
||||
health.HealthMonitor
|
||||
Name() string
|
||||
Key() string
|
||||
URL() *net.URL
|
||||
URL() *url.URL
|
||||
Weight() Weight
|
||||
SetWeight(weight Weight)
|
||||
TryWake() error
|
||||
@@ -38,7 +38,7 @@ type (
|
||||
|
||||
var NewServerPool = F.NewMap[Pool]
|
||||
|
||||
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
|
||||
func NewServer(name string, url *url.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
|
||||
srv := &server{
|
||||
name: name,
|
||||
url: url,
|
||||
@@ -52,7 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
|
||||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
||||
srv := &server{
|
||||
weight: Weight(weight),
|
||||
url: net.MustParseURL("http://localhost"),
|
||||
url: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
}
|
||||
return srv
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (srv *server) Name() string {
|
||||
return srv.name
|
||||
}
|
||||
|
||||
func (srv *server) URL() *net.URL {
|
||||
func (srv *server) URL() *url.URL {
|
||||
return srv.url
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
@@ -18,8 +17,8 @@ type (
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
CIDRWhitelistOpts struct {
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
|
||||
Allow []*net.IPNet `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
|
||||
Message string
|
||||
}
|
||||
)
|
||||
@@ -27,7 +26,7 @@ type (
|
||||
var (
|
||||
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
|
||||
cidrWhitelistDefaults = CIDRWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
Allow: []*net.IPNet{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
@@ -33,7 +32,7 @@ var (
|
||||
cfCIDRsMu sync.Mutex
|
||||
|
||||
// RFC 1918.
|
||||
localCIDRs = []*types.CIDR{
|
||||
localCIDRs = []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, // 127.0.0.1/32
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, // 10.0.0.0/8
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 240, 0, 0)}, // 172.16.0.0/12
|
||||
@@ -68,7 +67,7 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||
return cri.realIP.getTracer()
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
|
||||
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
@@ -83,7 +82,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if common.IsTest {
|
||||
cfCIDRs = localCIDRs
|
||||
} else {
|
||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||
cfCIDRs = make([]*net.IPNet, 0, 30)
|
||||
err := errors.Join(
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||
@@ -103,7 +102,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*net.IPNet) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -124,7 +123,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
}
|
||||
|
||||
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
|
||||
*cfCIDRs = append(*cfCIDRs, (*net.IPNet)(cidr))
|
||||
}
|
||||
*cfCIDRs = append(*cfCIDRs, localCIDRs...)
|
||||
return nil
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -51,8 +51,8 @@ func TestModifyRequest(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
|
||||
upstreamURL := Must(url.Parse("http://test.example.com"))
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
@@ -128,8 +128,8 @@ func TestModifyRequest(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app" + tt.path)
|
||||
upstreamURL := types.MustParseURL(tt.upstreamURL)
|
||||
reqURL := Must(url.Parse("https://my.app" + tt.path))
|
||||
upstreamURL := Must(url.Parse(tt.upstreamURL))
|
||||
|
||||
opts["add_prefix"] = tt.addPrefix
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -54,8 +54,8 @@ func TestModifyResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("response_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
|
||||
upstreamURL := Must(url.Parse("http://test.example.com"))
|
||||
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
@@ -19,7 +18,7 @@ type (
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `validate:"required,min=1"`
|
||||
From []*net.IPNet `validate:"required,min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
@@ -36,7 +35,7 @@ var (
|
||||
RealIP = NewMiddleware[realIP]()
|
||||
realIPOptsDefault = RealIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*types.CIDR{},
|
||||
From: []*net.IPNet{},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -23,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
|
||||
}
|
||||
optExpected := &RealIPOpts{
|
||||
Header: httpheaders.HeaderXRealIP,
|
||||
From: []*types.CIDR{
|
||||
From: []*net.IPNet{
|
||||
{
|
||||
IP: net.ParseIP("127.0.0.0"),
|
||||
Mask: net.IPv4Mask(255, 0, 0, 0),
|
||||
|
||||
@@ -2,15 +2,15 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPs(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("http://example.com"),
|
||||
reqURL: Must(url.Parse("http://example.com")),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
|
||||
@@ -19,7 +19,7 @@ func TestRedirectToHTTPs(t *testing.T) {
|
||||
|
||||
func TestNoRedirect(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("https://example.com"),
|
||||
reqURL: Must(url.Parse("https://example.com")),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -80,11 +80,11 @@ type TestResult struct {
|
||||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
upstreamURL *types.URL
|
||||
upstreamURL *url.URL
|
||||
|
||||
realRoundTrip bool
|
||||
|
||||
reqURL *types.URL
|
||||
reqURL *url.URL
|
||||
reqMethod string
|
||||
headers http.Header
|
||||
body []byte
|
||||
@@ -96,13 +96,13 @@ type testArgs struct {
|
||||
|
||||
func (args *testArgs) setDefaults() {
|
||||
if args.reqURL == nil {
|
||||
args.reqURL = Must(types.ParseURL("https://example.com"))
|
||||
args.reqURL = Must(url.Parse("https://example.com"))
|
||||
}
|
||||
if args.reqMethod == "" {
|
||||
args.reqMethod = http.MethodGet
|
||||
}
|
||||
if args.upstreamURL == nil {
|
||||
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
args.upstreamURL = Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
}
|
||||
if args.respHeaders == nil {
|
||||
args.respHeaders = http.Header{}
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
@@ -93,7 +92,7 @@ type ReverseProxy struct {
|
||||
HandlerFunc http.HandlerFunc
|
||||
|
||||
TargetName string
|
||||
TargetURL *types.URL
|
||||
TargetURL *url.URL
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
@@ -133,7 +132,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
func NewReverseProxy(name string, target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
if transport == nil {
|
||||
panic("nil transport")
|
||||
}
|
||||
@@ -151,7 +150,7 @@ func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
|
||||
targetQuery := p.TargetURL.RawQuery
|
||||
req.URL.Scheme = p.TargetURL.Scheme
|
||||
req.URL.Host = p.TargetURL.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(&p.TargetURL.URL, req.URL)
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL, req.URL)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//nolint:recvcheck
|
||||
type CIDR net.IPNet
|
||||
|
||||
func ParseCIDR(v string) (cidr CIDR, err error) {
|
||||
err = cidr.Parse(v)
|
||||
return
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Parse(v string) error {
|
||||
if !strings.Contains(v, "/") {
|
||||
v += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cidr.IP = ipnet.IP
|
||||
cidr.Mask = ipnet.Mask
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cidr CIDR) Contains(ip net.IP) bool {
|
||||
return (*net.IPNet)(&cidr).Contains(ip)
|
||||
}
|
||||
|
||||
func (cidr CIDR) String() string {
|
||||
return (*net.IPNet)(&cidr).String()
|
||||
}
|
||||
|
||||
func (cidr CIDR) MarshalText() ([]byte, error) {
|
||||
return []byte(cidr.String()), nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
urlPkg "net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type URL struct {
|
||||
_ utils.NoCopy
|
||||
urlPkg.URL
|
||||
}
|
||||
|
||||
func MustParseURL(url string) *URL {
|
||||
u, err := ParseURL(url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func ParseURL(url string) (*URL, error) {
|
||||
u := &URL{}
|
||||
return u, u.Parse(url)
|
||||
}
|
||||
|
||||
func NewURL(url *urlPkg.URL) *URL {
|
||||
return &URL{URL: *url}
|
||||
}
|
||||
|
||||
func (u *URL) Parse(url string) error {
|
||||
uu, err := urlPkg.Parse(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.URL = *uu
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *URL) String() string {
|
||||
if u == nil {
|
||||
return "nil"
|
||||
}
|
||||
return u.URL.String()
|
||||
}
|
||||
|
||||
func (u *URL) MarshalJSON() (text []byte, err error) {
|
||||
if u == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte("\"" + u.URL.String() + "\""), nil
|
||||
}
|
||||
|
||||
func (u *URL) Equals(other *URL) bool {
|
||||
return u.String() == other.String()
|
||||
}
|
||||
Reference in New Issue
Block a user