From e7c87bae776a5899d7fd650d518d0bf8e70e3ffb Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 5 Dec 2025 16:06:36 +0800 Subject: [PATCH] refactor(http,rules): move SharedData and ResponseModifier to httputils - implemented dependency injection for rule auth handler --- cmd/main.go | 4 + internal/api/v1/route/playground.go | 3 +- internal/net/gphttp/middleware/bypass.go | 3 +- internal/net/gphttp/middleware/middleware.go | 4 +- internal/route/rules/cache.go | 108 -------- internal/route/rules/crypto.go | 11 +- internal/route/rules/do.go | 13 +- internal/route/rules/do_set.go | 19 +- internal/route/rules/do_set_test.go | 11 +- internal/route/rules/on.go | 23 +- internal/route/rules/response_modifier.go | 267 ------------------- internal/route/rules/rules.go | 25 +- internal/route/rules/template.go | 6 +- internal/route/rules/var_bench_test.go | 4 +- internal/route/rules/vars.go | 7 +- internal/route/rules/vars_dynamic.go | 16 +- internal/route/rules/vars_static.go | 7 +- internal/route/rules/vars_test.go | 25 +- 18 files changed, 105 insertions(+), 451 deletions(-) delete mode 100644 internal/route/rules/cache.go delete mode 100644 internal/route/rules/response_modifier.go diff --git a/cmd/main.go b/cmd/main.go index 06008f59..c56d6bc5 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -16,6 +16,7 @@ import ( "github.com/yusing/godoxy/internal/metrics/systeminfo" "github.com/yusing/godoxy/internal/metrics/uptime" "github.com/yusing/godoxy/internal/net/gphttp/middleware" + "github.com/yusing/godoxy/internal/route/rules" gperr "github.com/yusing/goutils/errs" "github.com/yusing/goutils/server" "github.com/yusing/goutils/task" @@ -58,9 +59,12 @@ func main() { } config.StartProxyServers() + if err := auth.Initialize(); err != nil { log.Fatal().Err(err).Msg("failed to initialize authentication") } + rules.InitAuthHandler(auth.AuthOrProceed) + // API Handler needs to start after auth is initialized. server.StartServer(task.RootTask("api_server", false), server.Options{ Name: "api", diff --git a/internal/api/v1/route/playground.go b/internal/api/v1/route/playground.go index 44c15740..b7321839 100644 --- a/internal/api/v1/route/playground.go +++ b/internal/api/v1/route/playground.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/godoxy/internal/route/rules" apitypes "github.com/yusing/goutils/apitypes" gperr "github.com/yusing/goutils/errs" + httputils "github.com/yusing/goutils/http" ) type RawRule struct { @@ -348,7 +349,7 @@ func checkMatchedRules(rulesList rules.Rules, w http.ResponseWriter, r *http.Req var matched []string // Create a ResponseModifier to properly check rules - rm := rules.NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) for _, rule := range rulesList { // Check if rule matches diff --git a/internal/net/gphttp/middleware/bypass.go b/internal/net/gphttp/middleware/bypass.go index 85f618c0..83dab822 100644 --- a/internal/net/gphttp/middleware/bypass.go +++ b/internal/net/gphttp/middleware/bypass.go @@ -7,6 +7,7 @@ import ( "github.com/rs/zerolog/log" "github.com/yusing/godoxy/internal/auth" "github.com/yusing/godoxy/internal/route/rules" + httputils "github.com/yusing/goutils/http" ) type Bypass []rules.RuleOn @@ -50,7 +51,7 @@ func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNex } func (c *checkBypass) modifyResponse(resp *http.Response) error { - if c.modRes == nil || (!c.isEnforced(resp.Request) && c.bypass.ShouldBypass(rules.ResponseAsRW(resp), resp.Request)) { + if c.modRes == nil || (!c.isEnforced(resp.Request) && c.bypass.ShouldBypass(httputils.ResponseAsRW(resp), resp.Request)) { return nil } log.Debug().Str("middleware", c.name).Str("url", resp.Request.Host+resp.Request.URL.Path).Msg("modifying response") diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index d3aa7e18..731256e8 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -9,9 +9,9 @@ import ( "github.com/bytedance/sonic" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/yusing/godoxy/internal/route/rules" "github.com/yusing/godoxy/internal/serialization" gperr "github.com/yusing/goutils/errs" + httputils "github.com/yusing/goutils/http" "github.com/yusing/goutils/http/httpheaders" "github.com/yusing/goutils/http/reverseproxy" ) @@ -197,7 +197,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * } if exec, ok := m.impl.(ResponseModifier); ok { - rm := rules.NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) defer rm.FlushRelease() next(rm, r) diff --git a/internal/route/rules/cache.go b/internal/route/rules/cache.go deleted file mode 100644 index afb8ac3b..00000000 --- a/internal/route/rules/cache.go +++ /dev/null @@ -1,108 +0,0 @@ -package rules - -import ( - "net" - "net/http" - "net/url" - "sync" -) - -// Cache is a map of cached values for a request. -// It prevents the same value from being parsed multiple times. -type ( - Cache map[string]any - UpdateFunc[T any] func(T) T -) - -const ( - cacheKeyQueries = "queries" - cacheKeyCookies = "cookies" - cacheKeyRemoteIP = "remote_ip" - cacheKeyBasicAuth = "basic_auth" -) - -var cachePool = sync.Pool{ - New: func() any { - return make(Cache) - }, -} - -// NewCache returns a new Cached. -func NewCache() Cache { - return cachePool.Get().(Cache) -} - -// Release clear the contents of the Cached and returns it to the pool. -func (c Cache) Release() { - clear(c) - cachePool.Put(c) -} - -// GetQueries returns the queries. -// If r does not have queries, an empty map is returned. -func (c Cache) GetQueries(r *http.Request) url.Values { - v, ok := c[cacheKeyQueries] - if !ok { - v = r.URL.Query() - c[cacheKeyQueries] = v - } - return v.(url.Values) -} - -func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) { - queries := c.GetQueries(r) - update(queries) - r.URL.RawQuery = queries.Encode() -} - -// GetCookies returns the cookies. -// If r does not have cookies, an empty slice is returned. -func (c Cache) GetCookies(r *http.Request) []*http.Cookie { - v, ok := c[cacheKeyCookies] - if !ok { - v = r.Cookies() - c[cacheKeyCookies] = v - } - return v.([]*http.Cookie) -} - -func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) { - cookies := update(c.GetCookies(r)) - c[cacheKeyCookies] = cookies - r.Header.Del("Cookie") - for _, cookie := range cookies { - r.AddCookie(cookie) - } -} - -// GetRemoteIP returns the remote ip address. -// If r.RemoteAddr is not a valid ip address, nil is returned. -func (c Cache) GetRemoteIP(r *http.Request) net.IP { - v, ok := c[cacheKeyRemoteIP] - if !ok { - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - host = r.RemoteAddr - } - v = net.ParseIP(host) - c[cacheKeyRemoteIP] = v - } - return v.(net.IP) -} - -// GetBasicAuth returns *Credentials the basic auth username and password. -// If r does not have basic auth, nil is returned. -func (c Cache) GetBasicAuth(r *http.Request) *Credentials { - v, ok := c[cacheKeyBasicAuth] - if !ok { - u, p, ok := r.BasicAuth() - if ok { - v = &Credentials{u, []byte(p)} - c[cacheKeyBasicAuth] = v - } else { - c[cacheKeyBasicAuth] = nil - return nil - } - } - return v.(*Credentials) -} diff --git a/internal/route/rules/crypto.go b/internal/route/rules/crypto.go index 3b05a1fd..b29f6353 100644 --- a/internal/route/rules/crypto.go +++ b/internal/route/rules/crypto.go @@ -1,16 +1,15 @@ package rules -import "golang.org/x/crypto/bcrypt" +import ( + httputils "github.com/yusing/goutils/http" + "golang.org/x/crypto/bcrypt" +) type ( HashedCrendentials struct { Username string CheckMatch func(inputPwd []byte) bool } - Credentials struct { - Username string - Password []byte - } ) func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials { @@ -19,7 +18,7 @@ func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendenti }} } -func (hc *HashedCrendentials) Match(cred *Credentials) bool { +func (hc *HashedCrendentials) Match(cred *httputils.Credentials) bool { if cred == nil { return false } diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index e1aa3af3..3189f586 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/rs/zerolog" - "github.com/yusing/godoxy/internal/auth" "github.com/yusing/godoxy/internal/logging" gphttp "github.com/yusing/godoxy/internal/net/gphttp" nettypes "github.com/yusing/godoxy/internal/net/types" @@ -50,6 +49,14 @@ const ( CommandPassAlt = "bypass" ) +type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool) + +var authHandler AuthHandler + +func InitAuthHandler(handler AuthHandler) { + authHandler = handler +} + var commands = map[string]struct { help Help validate ValidateFunc @@ -70,7 +77,7 @@ var commands = map[string]struct { }, build: func(args any) CommandHandler { return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - if !auth.AuthOrProceed(w, r) { + if !authHandler(w, r) { return errTerminated } return nil @@ -198,7 +205,7 @@ var commands = map[string]struct { code, textTmpl := args.(*Tuple[int, templateString]).Unpack() return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { // error command should overwrite the response body - GetInitResponseModifier(w).ResetBody() + httputils.GetInitResponseModifier(w).ResetBody() w.WriteHeader(code) err := textTmpl.ExpandVars(w, r, w) return err diff --git a/internal/route/rules/do_set.go b/internal/route/rules/do_set.go index 6a18f60f..97e7a8ef 100644 --- a/internal/route/rules/do_set.go +++ b/internal/route/rules/do_set.go @@ -7,6 +7,7 @@ import ( "strconv" gperr "github.com/yusing/goutils/errs" + httputils "github.com/yusing/goutils/http" ioutils "github.com/yusing/goutils/io" ) @@ -128,7 +129,7 @@ var modFields = map[string]struct { if err != nil { return err } - GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { queries.Set(k, v) }) return nil @@ -138,13 +139,13 @@ var modFields = map[string]struct { if err != nil { return err } - GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { queries.Add(k, v) }) return nil }), remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { queries.Del(k) }) return nil @@ -169,7 +170,7 @@ var modFields = map[string]struct { if err != nil { return err } - GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { for i, c := range cookies { if c.Name == k { cookies[i].Value = v @@ -185,13 +186,13 @@ var modFields = map[string]struct { if err != nil { return err } - GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { return append(cookies, &http.Cookie{Name: k, Value: v}) }) return nil }), remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { index := -1 for i, c := range cookies { if c.Name == k { @@ -242,7 +243,7 @@ var modFields = map[string]struct { r.Body = nil } - bufPool := GetInitResponseModifier(w).BufPool() + bufPool := httputils.GetInitResponseModifier(w).BufPool() b := bufPool.GetBuffer() err := tmpl.ExpandVars(w, r, b) if err != nil { @@ -282,7 +283,7 @@ var modFields = map[string]struct { tmpl := args.(templateString) return &FieldHandler{ set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - rm := GetInitResponseModifier(w) + rm := httputils.GetInitResponseModifier(w) rm.ResetBody() return tmpl.ExpandVars(w, r, rm) }), @@ -317,7 +318,7 @@ var modFields = map[string]struct { status := args.(int) return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - GetInitResponseModifier(w).WriteHeader(status) + httputils.GetInitResponseModifier(w).WriteHeader(status) return nil }), } diff --git a/internal/route/rules/do_set_test.go b/internal/route/rules/do_set_test.go index aae2c0a6..4d77ccd1 100644 --- a/internal/route/rules/do_set_test.go +++ b/internal/route/rules/do_set_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + httputils "github.com/yusing/goutils/http" ) func TestFieldHandler_Header(t *testing.T) { @@ -420,7 +421,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) { name string template string setup func(*http.Request) - verify func(*ResponseModifier) + verify func(*httputils.ResponseModifier) }{ { name: "set response body with template", @@ -429,8 +430,8 @@ func TestFieldHandler_ResponseBody(t *testing.T) { r.Method = "GET" r.URL.Path = "/api/test" }, - verify: func(rm *ResponseModifier) { - content := rm.buf.String() + verify: func(rm *httputils.ResponseModifier) { + content := string(rm.Content()) expected := "Response: GET /api/test" assert.Equal(t, expected, content, "Expected response body") }, @@ -444,7 +445,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) { w := httptest.NewRecorder() // Create ResponseModifier wrapper - rm := NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) tmpl, tErr := validateTemplate(tt.template, false) if tErr != nil { @@ -495,7 +496,7 @@ func TestFieldHandler_StatusCode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - rm := NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) var cmd Command err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status)) if err != nil { diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 1d0ef54f..e25ce8cc 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -8,6 +8,7 @@ import ( "github.com/yusing/godoxy/internal/route/routes" gperr "github.com/yusing/goutils/errs" + httputils "github.com/yusing/goutils/http" ) type RuleOn struct { @@ -95,11 +96,11 @@ var checkers = map[string]struct { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { return func(w http.ResponseWriter, r *http.Request) bool { - return len(GetInitResponseModifier(w).Header()[k]) > 0 + return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0 } } return func(w http.ResponseWriter, r *http.Request) bool { - return slices.ContainsFunc(GetInitResponseModifier(w).Header()[k], matcher) + return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher) } }, }, @@ -122,11 +123,11 @@ var checkers = map[string]struct { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { return func(w http.ResponseWriter, r *http.Request) bool { - return len(GetSharedData(w).GetQueries(r)[k]) > 0 + return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0 } } return func(w http.ResponseWriter, r *http.Request) bool { - return slices.ContainsFunc(GetSharedData(w).GetQueries(r)[k], matcher) + return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher) } }, }, @@ -149,7 +150,7 @@ var checkers = map[string]struct { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { return func(w http.ResponseWriter, r *http.Request) bool { - cookies := GetSharedData(w).GetCookies(r) + cookies := httputils.GetSharedData(w).GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { return true @@ -159,7 +160,7 @@ var checkers = map[string]struct { } } return func(w http.ResponseWriter, r *http.Request) bool { - cookies := GetSharedData(w).GetCookies(r) + cookies := httputils.GetSharedData(w).GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { if matcher(cookie.Value) { @@ -302,7 +303,7 @@ var checkers = map[string]struct { if ones, bits := ipnet.Mask.Size(); ones == bits { wantIP := ipnet.IP return func(w http.ResponseWriter, r *http.Request) bool { - ip := GetSharedData(w).GetRemoteIP(r) + ip := httputils.GetSharedData(w).GetRemoteIP(r) if ip == nil { return false } @@ -310,7 +311,7 @@ var checkers = map[string]struct { } } return func(w http.ResponseWriter, r *http.Request) bool { - ip := GetSharedData(w).GetRemoteIP(r) + ip := httputils.GetSharedData(w).GetRemoteIP(r) if ip == nil { return false } @@ -330,7 +331,7 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { cred := args.(*HashedCrendentials) return func(w http.ResponseWriter, r *http.Request) bool { - return cred.Match(GetSharedData(w).GetBasicAuth(r)) + return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r)) } }, }, @@ -378,11 +379,11 @@ var checkers = map[string]struct { beg, end := args.(*IntTuple).Unpack() if beg == end { return func(w http.ResponseWriter, _ *http.Request) bool { - return GetInitResponseModifier(w).StatusCode() == beg + return httputils.GetInitResponseModifier(w).StatusCode() == beg } } return func(w http.ResponseWriter, _ *http.Request) bool { - statusCode := GetInitResponseModifier(w).StatusCode() + statusCode := httputils.GetInitResponseModifier(w).StatusCode() return statusCode >= beg && statusCode <= end } }, diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go deleted file mode 100644 index d03dda4d..00000000 --- a/internal/route/rules/response_modifier.go +++ /dev/null @@ -1,267 +0,0 @@ -package rules - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "net" - "net/http" - "strconv" - - "github.com/rs/zerolog/log" - gperr "github.com/yusing/goutils/errs" - "github.com/yusing/goutils/synk" -) - -type ResponseModifier struct { - bufPool synk.UnsizedBytesPool - - w http.ResponseWriter - buf *bytes.Buffer - statusCode int - shared Cache - - origContentLength int64 // from http.Response in ResponseAsRW, -1 if not set - bodyModified bool - - hijacked bool - - errs gperr.Builder -} - -type Response struct { - StatusCode int - Header http.Header -} - -func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier { - for { - switch ww := w.(type) { - case *ResponseModifier: - return ww - case interface{ Unwrap() http.ResponseWriter }: - w = ww.Unwrap() - default: - return nil - } - } -} - -type responseAsRW struct { - resp *http.Response -} - -func (r responseAsRW) WriteHeader(code int) { - log.Error().Msg("write header after response has been created") -} - -func (r responseAsRW) Write(b []byte) (int, error) { - return 0, io.ErrClosedPipe -} - -func (r responseAsRW) Header() http.Header { - return r.resp.Header -} - -func ResponseAsRW(resp *http.Response) *ResponseModifier { - return &ResponseModifier{ - statusCode: resp.StatusCode, - w: responseAsRW{resp}, - origContentLength: resp.ContentLength, - } -} - -// GetInitResponseModifier returns the response modifier for the given response writer. -// If the response writer is already wrapped, it will return the wrapped response modifier. -// Otherwise, it will return a new response modifier. -func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier { - if rm := unwrapResponseModifier(w); rm != nil { - return rm - } - return NewResponseModifier(w) -} - -// GetSharedData returns the shared data for the given response writer. -// It will initialize the shared data if not initialized. -func GetSharedData(w http.ResponseWriter) Cache { - rm := GetInitResponseModifier(w) - if rm.shared == nil { - rm.shared = NewCache() - } - return rm.shared -} - -// NewResponseModifier returns a new response modifier for the given response writer. -// -// It should only be called once, at the very beginning of the request. -func NewResponseModifier(w http.ResponseWriter) *ResponseModifier { - return &ResponseModifier{ - bufPool: synk.GetUnsizedBytesPool(), - w: w, - origContentLength: -1, - } -} - -func (rm *ResponseModifier) BufPool() synk.UnsizedBytesPool { - return rm.bufPool -} - -// func (rm *ResponseModifier) Unwrap() http.ResponseWriter { -// return rm.w -// } - -func (rm *ResponseModifier) WriteHeader(code int) { - rm.statusCode = code -} - -// BodyReader returns a reader for the response body. -// Every call to this function will return a new reader that starts from the beginning of the buffer. -func (rm *ResponseModifier) BodyReader() io.ReadCloser { - if rm.buf == nil { - return io.NopCloser(bytes.NewReader(nil)) - } - return io.NopCloser(bytes.NewReader(rm.buf.Bytes())) -} - -func (rm *ResponseModifier) ResetBody() { - if !rm.bodyModified { - return - } - if rm.buf == nil { - return - } - rm.buf.Reset() -} - -func (rm *ResponseModifier) SetBody(r io.ReadCloser) error { - if rm.buf == nil { - rm.buf = rm.bufPool.GetBuffer() - } else { - rm.buf.Reset() - } - - rm.bodyModified = true - - _, err := io.Copy(rm.buf, r) - if err != nil { - return fmt.Errorf("failed to copy body: %w", err) - } - r.Close() - return nil -} - -func (rm *ResponseModifier) ContentLength() int { - if !rm.bodyModified { - if rm.origContentLength >= 0 { - return int(rm.origContentLength) - } - contentLength, _ := strconv.Atoi(rm.ContentLengthStr()) - return contentLength - } - return rm.buf.Len() -} - -func (rm *ResponseModifier) ContentLengthStr() string { - if !rm.bodyModified { - if rm.origContentLength >= 0 { - return strconv.FormatInt(rm.origContentLength, 10) - } - return rm.w.Header().Get("Content-Length") - } - return strconv.Itoa(rm.buf.Len()) -} - -func (rm *ResponseModifier) Content() []byte { - if rm.buf == nil { - return nil - } - return rm.buf.Bytes() -} - -func (rm *ResponseModifier) StatusCode() int { - if rm.statusCode == 0 { - return http.StatusOK - } - return rm.statusCode -} - -func (rm *ResponseModifier) Header() http.Header { - return rm.w.Header() -} - -func (rm *ResponseModifier) Response() Response { - return Response{StatusCode: rm.StatusCode(), Header: rm.Header()} -} - -func (rm *ResponseModifier) Write(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - - rm.bodyModified = true - if rm.buf == nil { - rm.buf = rm.bufPool.GetBuffer() - } - return rm.buf.Write(b) -} - -// AppendError appends an error to the response modifier -// the error will be formatted as "rule error: " -// -// It will be aggregated and returned in FlushRelease. -func (rm *ResponseModifier) AppendError(rule Rule, err error) { - rm.errs.Addf("rule %q error: %w", rule.Name, err) -} - -func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hijacker, ok := rm.w.(http.Hijacker); ok { - rm.hijacked = true - return hijacker.Hijack() - } - return nil, nil, errors.New("hijack not supported") -} - -// FlushRelease flushes the response modifier and releases the resources -// it returns the number of bytes written and the aggregated error -// if there is any error (rule errors or write error), it will be returned -func (rm *ResponseModifier) FlushRelease() (int, error) { - n := 0 - if !rm.hijacked { - if rm.bodyModified { - h := rm.w.Header() - h.Set("Content-Length", rm.ContentLengthStr()) - h.Del("Transfer-Encoding") - h.Del("Trailer") - } - rm.w.WriteHeader(rm.StatusCode()) - - if rm.bodyModified { - if content := rm.Content(); len(content) > 0 { - nn, werr := rm.w.Write(content) - n += nn - if werr != nil { - rm.errs.Addf("write error: %w", werr) - } - if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - rm.errs.Addf("flush error: %w", err) - } - } - } - } - - // release the buffer and reset the pointers - if rm.buf != nil { - rm.bufPool.PutBuffer(rm.buf) - rm.buf = nil - } - - // release the shared data - if rm.shared != nil { - rm.shared.Release() - rm.shared = nil - } - - return n, rm.errs.Error() -} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index 896eac65..ab014492 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -7,6 +7,7 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog/log" + httputils "github.com/yusing/goutils/http" "golang.org/x/net/http2" _ "unsafe" @@ -91,7 +92,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { } if defaultRule.IsResponseRule() { return func(w http.ResponseWriter, r *http.Request) { - rm := NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) defer func() { if _, err := rm.FlushRelease(); err != nil { logError(err, r) @@ -101,12 +102,12 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { up(w, r) err := defaultRule.Do.exec.Handle(w, r) if err != nil && !errors.Is(err, errTerminated) { - rm.AppendError(defaultRule, err) + appendRuleError(rm, &defaultRule, err) } } } return func(w http.ResponseWriter, r *http.Request) { - rm := NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) defer func() { if _, err := rm.FlushRelease(); err != nil { logError(err, r) @@ -119,7 +120,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { return } if !errors.Is(err, errTerminated) { - rm.AppendError(defaultRule, err) + appendRuleError(rm, &defaultRule, err) } } } @@ -138,7 +139,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { defaultTerminates := isTerminatingHandler(defaultRule.Do.exec) return func(w http.ResponseWriter, r *http.Request) { - rm := NewResponseModifier(w) + rm := httputils.NewResponseModifier(w) defer func() { if _, err := rm.FlushRelease(); err != nil { logError(err, r) @@ -157,7 +158,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { err := defaultRule.Handle(w, r) if err != nil { if !errors.Is(err, errTerminated) { - rm.AppendError(defaultRule, err) + appendRuleError(rm, &defaultRule, err) } shouldCallUpstream = false } @@ -174,7 +175,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { err := rule.Handle(w, r) if err != nil { if !errors.Is(err, errTerminated) { - rm.AppendError(rule, err) + appendRuleError(rm, &rule, err) } shouldCallUpstream = false break @@ -190,7 +191,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { err := defaultRule.Handle(w, r) if err != nil { if !errors.Is(err, errTerminated) { - rm.AppendError(defaultRule, err) + appendRuleError(rm, &defaultRule, err) return } shouldCallUpstream = false @@ -212,7 +213,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { err := rule.Handle(w, r) if err != nil { if !errors.Is(err, errTerminated) { - rm.AppendError(rule, err) + appendRuleError(rm, &rule, err) } return } @@ -222,12 +223,16 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { if isDefaultRulePost { err := defaultRule.Handle(w, r) if err != nil && !errors.Is(err, errTerminated) { - rm.AppendError(defaultRule, err) + appendRuleError(rm, &defaultRule, err) } } } } +func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) { + rm.AppendError("rule: %s, error: %w", rule.Name, err) +} + func isTerminatingHandler(handler CommandHandler) bool { switch h := handler.(type) { case TerminatingCommand: diff --git a/internal/route/rules/template.go b/internal/route/rules/template.go index 42c6f2dd..f9ee576a 100644 --- a/internal/route/rules/template.go +++ b/internal/route/rules/template.go @@ -5,6 +5,8 @@ import ( "net/http" "strings" "unsafe" + + httputils "github.com/yusing/goutils/http" ) type templateString struct { @@ -27,7 +29,7 @@ func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, return err } - return ExpandVars(GetInitResponseModifier(w), req, tmpl.string, dstW) + return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW) } func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) { @@ -36,7 +38,7 @@ func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http. } var buf strings.Builder - err := ExpandVars(GetInitResponseModifier(w), req, tmpl.string, &buf) + err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf) if err != nil { return "", err } diff --git a/internal/route/rules/var_bench_test.go b/internal/route/rules/var_bench_test.go index bde8cf52..74cef1d3 100644 --- a/internal/route/rules/var_bench_test.go +++ b/internal/route/rules/var_bench_test.go @@ -5,10 +5,12 @@ import ( "net/http/httptest" "net/url" "testing" + + httputils "github.com/yusing/goutils/http" ) func BenchmarkExpandVars(b *testing.B) { - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.WriteHeader(200) testResponseModifier.Write([]byte("Hello, world!")) testRequest := httptest.NewRequest("GET", "/", nil) diff --git a/internal/route/rules/vars.go b/internal/route/rules/vars.go index 46cd50f1..4d29eab4 100644 --- a/internal/route/rules/vars.go +++ b/internal/route/rules/vars.go @@ -8,6 +8,7 @@ import ( "regexp" "strings" + httputils "github.com/yusing/goutils/http" ioutils "github.com/yusing/goutils/io" ) @@ -15,7 +16,7 @@ import ( type ( reqVarGetter func(*http.Request) string - respVarGetter func(*ResponseModifier) string + respVarGetter func(*httputils.ResponseModifier) string ) var reVar = regexp.MustCompile(`\$[\w_]+`) @@ -36,7 +37,7 @@ func NeedExpandVars(s string) bool { } var ( - voidResponseModifier = NewResponseModifier(httptest.NewRecorder()) + voidResponseModifier = httputils.NewResponseModifier(httptest.NewRecorder()) dummyRequest = http.Request{ Method: "GET", URL: &url.URL{Path: "/"}, @@ -50,7 +51,7 @@ func ValidateVars(s string) error { return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard) } -func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error { +func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error { dst := ioutils.NewBufferedWriter(dstW, 1024) defer dst.Close() diff --git a/internal/route/rules/vars_dynamic.go b/internal/route/rules/vars_dynamic.go index 29208b28..717064bf 100644 --- a/internal/route/rules/vars_dynamic.go +++ b/internal/route/rules/vars_dynamic.go @@ -4,6 +4,8 @@ import ( "net/http" "net/url" "strconv" + + httputils "github.com/yusing/goutils/http" ) var ( @@ -14,31 +16,31 @@ var ( VarPostForm = "postform" ) -type dynamicVarGetter func(args []string, w *ResponseModifier, req *http.Request) (string, error) +type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) var dynamicVarSubsMap = map[string]dynamicVarGetter{ - VarHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { key, index, err := getKeyAndIndex(args) if err != nil { return "", err } return getValueByKeyAtIndex(req.Header, key, index) }, - VarResponseHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { key, index, err := getKeyAndIndex(args) if err != nil { return "", err } return getValueByKeyAtIndex(w.Header(), key, index) }, - VarQuery: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { key, index, err := getKeyAndIndex(args) if err != nil { return "", err } - return getValueByKeyAtIndex(GetSharedData(w).GetQueries(req), key, index) + return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index) }, - VarForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { key, index, err := getKeyAndIndex(args) if err != nil { return "", err @@ -50,7 +52,7 @@ var dynamicVarSubsMap = map[string]dynamicVarGetter{ } return getValueByKeyAtIndex(req.Form, key, index) }, - VarPostForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { key, index, err := getKeyAndIndex(args) if err != nil { return "", err diff --git a/internal/route/rules/vars_static.go b/internal/route/rules/vars_static.go index ca52c1ee..cf27fd24 100644 --- a/internal/route/rules/vars_static.go +++ b/internal/route/rules/vars_static.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/yusing/godoxy/internal/route/routes" + httputils "github.com/yusing/goutils/http" ) const ( @@ -87,9 +88,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ } var staticRespVarSubsMap = map[string]respVarGetter{ - VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") }, - VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() }, - VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) }, + VarRespContentType: func(resp *httputils.ResponseModifier) string { return resp.Header().Get("Content-Type") }, + VarRespContentLen: func(resp *httputils.ResponseModifier) string { return resp.ContentLengthStr() }, + VarRespStatusCode: func(resp *httputils.ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) }, } func stripFragment(s string) string { diff --git a/internal/route/rules/vars_test.go b/internal/route/rules/vars_test.go index 919a691f..3a090751 100644 --- a/internal/route/rules/vars_test.go +++ b/internal/route/rules/vars_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/require" + httputils "github.com/yusing/goutils/http" ) func TestExtractArgs(t *testing.T) { @@ -214,7 +215,7 @@ func TestExpandVars(t *testing.T) { testRequest.PostForm = postFormData // Create response modifier with headers - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.Header().Set("Content-Type", "text/html") testResponseModifier.Header().Set("X-Custom-Resp", "resp-value") testResponseModifier.WriteHeader(200) @@ -483,7 +484,7 @@ func TestExpandVars(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, tt.input, &out) + err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, tt.input, &out) if tt.wantErr { require.Error(t, err) @@ -501,11 +502,11 @@ func TestExpandVars_Integration(t *testing.T) { testRequest.Header.Set("User-Agent", "curl/7.68.0") testRequest.RemoteAddr = "10.0.0.1:54321" - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.WriteHeader(200) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, + err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, "$req_method $req_url $status_code User-Agent=$header(User-Agent)", &out) @@ -516,7 +517,7 @@ func TestExpandVars_Integration(t *testing.T) { t.Run("with query parameters", func(t *testing.T) { testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil) - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, @@ -530,13 +531,13 @@ func TestExpandVars_Integration(t *testing.T) { t.Run("response headers", func(t *testing.T) { testRequest := httptest.NewRequest("GET", "/", nil) - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.Header().Set("Cache-Control", "no-cache") testResponseModifier.Header().Set("X-Rate-Limit", "100") testResponseModifier.WriteHeader(200) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, + err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, "Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)", &out) @@ -569,7 +570,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out) require.NoError(t, err) @@ -582,7 +583,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) { // Upstream variables require context from routes package testRequest := httptest.NewRequest("GET", "/", nil) - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) // Test that upstream variables don't cause errors even when not set upstreamVars := []string{ @@ -609,7 +610,7 @@ func TestExpandVars_NoHostPort(t *testing.T) { testRequest := httptest.NewRequest("GET", "/", nil) testRequest.Host = "example.com" // No port - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) t.Run("req_host without port", func(t *testing.T) { var out strings.Builder @@ -631,7 +632,7 @@ func TestExpandVars_NoRemotePort(t *testing.T) { testRequest := httptest.NewRequest("GET", "/", nil) testRequest.RemoteAddr = "192.168.1.1" // No port - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) t.Run("remote_host without port", func(t *testing.T) { var out strings.Builder @@ -650,7 +651,7 @@ func TestExpandVars_NoRemotePort(t *testing.T) { func TestExpandVars_WhitespaceHandling(t *testing.T) { testRequest := httptest.NewRequest("GET", "/test", nil) - testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)