diff --git a/Makefile b/Makefile index 61876796..7f3621b7 100755 --- a/Makefile +++ b/Makefile @@ -28,23 +28,21 @@ get: go get -u ./cmd && go mod tidy debug: - make build - sudo GOPROXY_DEBUG=1 bin/go-proxy + GOPROXY_DEBUG=1 make run debug-trace: - make build - sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy + GOPROXY_DEBUG=1 GOPROXY_TRACE=1 run profile: - GODEBUG=gctrace=1 make build - sudo GOPROXY_DEBUG=1 bin/go-proxy + GODEBUG=gctrace=1 make debug + +run: build + sudo setcap CAP_NET_BIND_SERVICE=+eip bin/go-proxy + bin/go-proxy mtrace: bin/go-proxy debug-ls-mtrace > mtrace.json -run: - make build && sudo bin/go-proxy - archive: git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip diff --git a/go.mod b/go.mod index 306f62ac..2302e579 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/santhosh-tekuri/jsonschema v1.2.4 golang.org/x/net v0.30.0 golang.org/x/text v0.19.0 + golang.org/x/time v0.7.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -57,7 +58,6 @@ require ( golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect - golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.26.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gotest.tools/v3 v3.5.1 // indirect diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index ec198698..86664ca6 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -10,7 +10,7 @@ import ( ) type cidrWhitelist struct { - *cidrWhitelistOpts + cidrWhitelistOpts m *Middleware } @@ -22,18 +22,15 @@ type cidrWhitelistOpts struct { cachedAddr F.Map[string, bool] // cache for trusted IPs } -var CIDRWhiteList = &cidrWhitelist{ - m: &Middleware{withOptions: NewCIDRWhitelist}, -} - -var cidrWhitelistDefaults = func() *cidrWhitelistOpts { - return &cidrWhitelistOpts{ +var ( + CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist} + cidrWhitelistDefaults = cidrWhitelistOpts{ Allow: []*types.CIDR{}, StatusCode: http.StatusForbidden, Message: "IP not allowed", cachedAddr: F.NewMapOf[string, bool](), } -} +) func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { wl := new(cidrWhitelist) @@ -41,8 +38,8 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { impl: wl, before: wl.checkIP, } - wl.cidrWhitelistOpts = cidrWhitelistDefaults() - err := Deserialize(opts, wl.cidrWhitelistOpts) + wl.cidrWhitelistOpts = cidrWhitelistDefaults + err := Deserialize(opts, &wl.cidrWhitelistOpts) if err != nil { return nil, err } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index dd5fc69a..3c278d99 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -27,8 +27,8 @@ func TestCIDRWhitelist(t *testing.T) { for range 10 { result, err := newMiddlewareTest(deny, nil) ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode) - ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message) + ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode) + ExpectEqual(t, string(result.Data), cidrWhitelistDefaults.Message) } }) diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index aa823459..64182a9c 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -29,9 +29,7 @@ var ( cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger() ) -var CloudflareRealIP = &realIP{ - m: &Middleware{withOptions: NewCloudflareRealIP}, -} +var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP} func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) { cri := new(realIP) @@ -46,7 +44,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) { next(w, r) }, } - cri.realIPOpts = &realIPOpts{ + cri.realIPOpts = realIPOpts{ Header: "CF-Connecting-IP", Recursive: true, } diff --git a/internal/net/http/middleware/errors.go b/internal/net/http/middleware/errors.go new file mode 100644 index 00000000..faf40383 --- /dev/null +++ b/internal/net/http/middleware/errors.go @@ -0,0 +1,5 @@ +package middleware + +import E "github.com/yusing/go-proxy/internal/error" + +var ErrZeroValue = E.New("cannot be zero") diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 29ce0042..dcaf1d62 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -19,7 +19,7 @@ import ( type ( forwardAuth struct { - *forwardAuthOpts + forwardAuthOpts m *Middleware client http.Client } @@ -33,14 +33,11 @@ type ( } ) -var ForwardAuth = &forwardAuth{ - m: &Middleware{withOptions: NewForwardAuthfunc}, -} +var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc} func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { fa := new(forwardAuth) - fa.forwardAuthOpts = new(forwardAuthOpts) - if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil { + if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil { return nil, err } if _, err := url.Parse(fa.Address); err != nil { diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 0b2b80e6..04f9e8e7 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -28,7 +28,6 @@ type ( CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error) OptionsRaw = map[string]any - Options any Middleware struct { _ U.NoCopy diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 111bd63b..1a7bff47 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -35,20 +35,23 @@ func All() map[string]*Middleware { // initialize middleware names and label parsers. func init() { + // snakes and cases will be stripped on `Get` + // so keys are lowercase without snake. allMiddlewares = map[string]*Middleware{ "setxforwarded": SetXForwarded, "hidexforwarded": HideXForwarded, "redirecthttp": RedirectHTTP, - "modifyresponse": ModifyResponse.m, - "modifyrequest": ModifyRequest.m, + "modifyresponse": ModifyResponse, + "modifyrequest": ModifyRequest, "errorpage": CustomErrorPage, "customerrorpage": CustomErrorPage, - "realip": RealIP.m, - "cloudflarerealip": CloudflareRealIP.m, - "cidrwhitelist": CIDRWhiteList.m, + "realip": RealIP, + "cloudflarerealip": CloudflareRealIP, + "cidrwhitelist": CIDRWhiteList, + "ratelimit": RateLimiter, // !experimental - "forwardauth": ForwardAuth.m, + "forwardauth": ForwardAuth, // "oauth2": OAuth2.m, } names := make(map[*Middleware][]string) diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 0b0ce60d..02f531ca 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -7,7 +7,7 @@ import ( type ( modifyRequest struct { - *modifyRequestOpts + modifyRequestOpts m *Middleware } // order: set_headers -> add_headers -> hide_headers @@ -18,9 +18,7 @@ type ( } ) -var ModifyRequest = &modifyRequest{ - m: &Middleware{withOptions: NewModifyRequest}, -} +var ModifyRequest = &Middleware{withOptions: NewModifyRequest} func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyRequest) @@ -34,8 +32,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { impl: mr, before: Rewrite(mrFunc), } - mr.modifyRequestOpts = new(modifyRequestOpts) - err := Deserialize(optsRaw, mr.modifyRequestOpts) + err := Deserialize(optsRaw, &mr.modifyRequestOpts) if err != nil { return nil, err } diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 6d9d9ecc..4a2bc923 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyRequest.m.WithOptionsClone(opts) + mr, err := ModifyRequest.WithOptionsClone(opts) ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) @@ -23,7 +23,7 @@ func TestSetModifyRequest(t *testing.T) { }) t.Run("request_headers", func(t *testing.T) { - result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ + result, err := newMiddlewareTest(ModifyRequest, &testArgs{ middlewareOpt: opts, }) ExpectNoError(t, err) diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index 4edd7103..d38ab75a 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -9,16 +9,14 @@ import ( type ( modifyResponse struct { - *modifyResponseOpts + modifyResponseOpts m *Middleware } // order: set_headers -> add_headers -> hide_headers modifyResponseOpts = modifyRequestOpts ) -var ModifyResponse = &modifyResponse{ - m: &Middleware{withOptions: NewModifyResponse}, -} +var ModifyResponse = &Middleware{withOptions: NewModifyResponse} func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyResponse) @@ -28,8 +26,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { } else { mr.m.modifyResponse = mr.modifyResponse } - mr.modifyResponseOpts = new(modifyResponseOpts) - err := Deserialize(optsRaw, mr.modifyResponseOpts) + err := Deserialize(optsRaw, &mr.modifyResponseOpts) if err != nil { return nil, err } diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 370e590c..2672b5d4 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyResponse.m.WithOptionsClone(opts) + mr, err := ModifyResponse.WithOptionsClone(opts) ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) @@ -23,7 +23,7 @@ func TestSetModifyResponse(t *testing.T) { }) t.Run("request_headers", func(t *testing.T) { - result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{ + result, err := newMiddlewareTest(ModifyResponse, &testArgs{ middlewareOpt: opts, }) ExpectNoError(t, err) diff --git a/internal/net/http/middleware/rate_limit.go b/internal/net/http/middleware/rate_limit.go new file mode 100644 index 00000000..5634705e --- /dev/null +++ b/internal/net/http/middleware/rate_limit.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + E "github.com/yusing/go-proxy/internal/error" + "golang.org/x/time/rate" +) + +type ( + requestMap = map[string]*rate.Limiter + rateLimiter struct { + requestMap requestMap + newLimiter func() *rate.Limiter + m *Middleware + + mu sync.Mutex + } + + rateLimiterOpts struct { + Average int `json:"average"` + Burst int `json:"burst"` + Period time.Duration `json:"period"` + } +) + +var ( + RateLimiter = &Middleware{withOptions: NewRateLimiter} + rateLimiterOptsDefault = rateLimiterOpts{ + Average: 100, + Burst: 1, + Period: time.Second, + } +) + +func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) { + rl := new(rateLimiter) + opts := rateLimiterOptsDefault + err := Deserialize(optsRaw, &opts) + if err != nil { + return nil, err + } + switch { + case opts.Average == 0: + return nil, ErrZeroValue.Subject("average") + case opts.Burst == 0: + return nil, ErrZeroValue.Subject("burst") + case opts.Period == 0: + return nil, ErrZeroValue.Subject("period") + } + rl.requestMap = make(requestMap, 0) + rl.newLimiter = func() *rate.Limiter { + return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst) + } + rl.m = &Middleware{ + impl: rl, + before: rl.limit, + } + return rl.m, nil +} + +func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) { + rl.mu.Lock() + + limiter, ok := rl.requestMap[r.RemoteAddr] + if !ok { + limiter = rl.newLimiter() + rl.requestMap[r.RemoteAddr] = limiter + } + + rl.mu.Unlock() + + if limiter.Allow() { + next(w, r) + return + } + + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) +} diff --git a/internal/net/http/middleware/rate_limit_test.go b/internal/net/http/middleware/rate_limit_test.go new file mode 100644 index 00000000..ec21781f --- /dev/null +++ b/internal/net/http/middleware/rate_limit_test.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "net/http" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestRateLimit(t *testing.T) { + opts := OptionsRaw{ + "average": "10", + "burst": "10", + "period": "1s", + } + + rl, err := NewRateLimiter(opts) + ExpectNoError(t, err) + for range 10 { + result, err := newMiddlewareTest(rl, nil) + ExpectNoError(t, err) + ExpectEqual(t, result.ResponseStatus, http.StatusOK) + } + result, err := newMiddlewareTest(rl, nil) + ExpectNoError(t, err) + ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests) +} diff --git a/internal/net/http/middleware/rate_limiter.go b/internal/net/http/middleware/rate_limiter.go deleted file mode 100644 index ea5f9588..00000000 --- a/internal/net/http/middleware/rate_limiter.go +++ /dev/null @@ -1,12 +0,0 @@ -package middleware - -type ( - rateLimiter struct { - *rateLimiterOpts - m *Middleware - } - - rateLimiterOpts struct { - Count int `json:"count"` - } -) diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 0bbf4522..40a2bebd 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -10,7 +10,7 @@ import ( // https://nginx.org/en/docs/http/ngx_http_realip_module.html type realIP struct { - *realIPOpts + realIPOpts m *Middleware } @@ -30,16 +30,13 @@ type realIPOpts struct { Recursive bool `json:"recursive"` } -var RealIP = &realIP{ - m: &Middleware{withOptions: NewRealIP}, -} - -var realIPOptsDefault = func() *realIPOpts { - return &realIPOpts{ +var ( + RealIP = &Middleware{withOptions: NewRealIP} + realIPOptsDefault = realIPOpts{ Header: "X-Real-IP", From: []*types.CIDR{}, } -} +) func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) { riWithOpts := new(realIP) @@ -47,11 +44,14 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) { impl: riWithOpts, before: Rewrite(riWithOpts.setRealIP), } - riWithOpts.realIPOpts = realIPOptsDefault() - err := Deserialize(opts, riWithOpts.realIPOpts) + riWithOpts.realIPOpts = realIPOptsDefault + err := Deserialize(opts, &riWithOpts.realIPOpts) if err != nil { return nil, err } + if len(riWithOpts.From) == 0 { + return nil, E.New("no allowed CIDRs").Subject("from") + } return riWithOpts.m, nil } @@ -70,9 +70,10 @@ func (ri *realIP) setRealIP(req *Request) { if err != nil { clientIPStr = req.RemoteAddr } - clientIP := net.ParseIP(clientIPStr) - var isTrusted = false + clientIP := net.ParseIP(clientIPStr) + isTrusted := false + for _, CIDR := range ri.From { if CIDR.Contains(clientIP) { isTrusted = true