From 44536139c14408bd10354ad56ccd67ade37fa8d0 Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 15 Oct 2025 23:53:26 +0800 Subject: [PATCH] refactor: refine byte pools usage and fix memory leak in rules --- internal/logging/accesslog/access_logger.go | 4 +- internal/logging/accesslog/rotate.go | 7 ++- internal/route/rules/do.go | 13 ++-- internal/route/rules/do_set.go | 8 +-- internal/route/rules/response_modifier.go | 67 ++++++++++++++------- internal/route/rules/template.go | 5 ++ internal/route/rules/validate.go | 11 +++- 7 files changed, 73 insertions(+), 42 deletions(-) diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index d2627097..9cc4fcac 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -157,12 +157,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) { } line := lineBufPool.Get() - defer lineBufPool.Put(line) line = l.AppendRequestLog(line, req, res) if line[len(line)-1] != '\n' { line = append(line, '\n') } l.write(line) + lineBufPool.Put(line) } func (l *AccessLogger) LogError(req *http.Request, err error) { @@ -171,12 +171,12 @@ func (l *AccessLogger) LogError(req *http.Request, err error) { func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { line := lineBufPool.Get() - defer lineBufPool.Put(line) line = l.AppendACLLog(line, info, blocked) if line[len(line)-1] != '\n' { line = append(line, '\n') } l.write(line) + lineBufPool.Put(line) } func (l *AccessLogger) ShouldRotate() bool { diff --git a/internal/logging/accesslog/rotate.go b/internal/logging/accesslog/rotate.go index 0bf220e5..b53d33b7 100644 --- a/internal/logging/accesslog/rotate.go +++ b/internal/logging/accesslog/rotate.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io" + "slices" "time" "github.com/rs/zerolog" @@ -167,14 +168,16 @@ func rotateLogFileByPolicy(file supportRotate, config *Retention, result *Rotate // Read each line and write it to the beginning of the file writePos := int64(0) buf := rotateBytePool.Get() - defer rotateBytePool.Put(buf) + defer func() { + rotateBytePool.Put(buf) + }() // in reverse order to keep the order of the lines (from old to new) for i := len(linesToKeep) - 1; i >= 0; i-- { line := linesToKeep[i] n := line.Size if cap(buf) < int(n) { - buf = make([]byte, n) + buf = slices.Grow(buf, int(n)-cap(buf)) } buf = buf[:n] diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index c86b07e5..4ed7f19c 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -19,7 +19,6 @@ import ( gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "github.com/yusing/goutils/http/reverseproxy" - "github.com/yusing/goutils/synk" ) type ( @@ -431,10 +430,7 @@ var commands = map[string]struct { to := []string{provider} return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - buf := bufPool.Get() - defer bufPool.Put(buf) - - respBuf := bytes.NewBuffer(buf) + respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len())) err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r) if err != nil { @@ -446,10 +442,11 @@ var commands = map[string]struct { return err } + b := respBuf.Bytes() notif.Notify(¬if.LogMessage{ Level: level, - Title: string(buf[:titleLen]), - Body: notif.MessageBodyBytes(buf[titleLen:]), + Title: string(b[:titleLen]), + Body: notif.MessageBodyBytes(b[titleLen:]), To: to, }) return nil @@ -466,8 +463,6 @@ type reqResponseTemplateData struct { } } -var bufPool = synk.GetBytesPoolWithUniqueMemory() - type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr] type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr] diff --git a/internal/route/rules/do_set.go b/internal/route/rules/do_set.go index b64e83ff..edab827f 100644 --- a/internal/route/rules/do_set.go +++ b/internal/route/rules/do_set.go @@ -1,7 +1,6 @@ package rules import ( - "bytes" "io" "net/http" "net/url" @@ -243,15 +242,14 @@ var modFields = map[string]struct { r.Body = nil } - buf := pool.Get() - b := bytes.NewBuffer(buf) - + bufPool := GetInitResponseModifier(w).BufPool() + b := bufPool.GetBuffer() err := executeRequestTemplateTo(tmpl, b, r) if err != nil { return err } r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() { - pool.Put(buf) + bufPool.PutBuffer(b) }) return nil }), diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go index 0e70d7fa..773615a7 100644 --- a/internal/route/rules/response_modifier.go +++ b/internal/route/rules/response_modifier.go @@ -13,8 +13,9 @@ import ( ) type ResponseModifier struct { + bufPool *synk.BytesPoolWithMemory + w http.ResponseWriter - b []byte // the bytes got from pool buf *bytes.Buffer statusCode int shared Cache @@ -29,8 +30,6 @@ type Response struct { Header http.Header } -var pool = synk.GetBytesPoolWithUniqueMemory() - func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier { for { switch ww := w.(type) { @@ -68,14 +67,16 @@ func GetSharedData(w http.ResponseWriter) Cache { // // It should only be called once, at the very beginning of the request. func NewResponseModifier(w http.ResponseWriter) *ResponseModifier { - b := pool.Get() return &ResponseModifier{ - w: w, - buf: bytes.NewBuffer(b), - b: b, + bufPool: synk.GetBytesPoolWithUniqueMemory(), + w: w, } } +func (rm *ResponseModifier) BufPool() *synk.BytesPoolWithMemory { + return rm.bufPool +} + // func (rm *ResponseModifier) Unwrap() http.ResponseWriter { // return rm.w // } @@ -85,13 +86,26 @@ func (rm *ResponseModifier) WriteHeader(code int) { } func (rm *ResponseModifier) ResetBody() { + if rm.buf == nil { + return + } rm.buf.Reset() } func (rm *ResponseModifier) ContentLength() int { + if rm.buf == nil { + return 0 + } return rm.buf.Len() } +func (rm *ResponseModifier) Content() []byte { + if rm.buf == nil { + return nil + } + return rm.buf.Bytes() +} + func (rm *ResponseModifier) StatusCode() int { if rm.statusCode == 0 { return http.StatusOK @@ -108,6 +122,9 @@ func (rm *ResponseModifier) Response() Response { } func (rm *ResponseModifier) Write(b []byte) (int, error) { + if rm.buf == nil { + rm.buf = rm.bufPool.GetBuffer() + } return rm.buf.Write(b) } @@ -139,29 +156,33 @@ func (rm *ResponseModifier) FlushRelease() (int, error) { // h.Del(k) // } // } - h.Set("Content-Length", strconv.Itoa(rm.buf.Len())) + contentLength := rm.ContentLength() + h.Set("Content-Length", strconv.Itoa(rm.ContentLength())) rm.w.WriteHeader(rm.StatusCode()) - nn, werr := rm.w.Write(rm.buf.Bytes()) - n += nn - if werr != nil { - rm.errs.Addf("write error: %w", werr) - } - // flush the response writer - if flusher, ok := rm.w.(http.Flusher); ok { - flusher.Flush() - } else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok { - ferr := errFlusher.Flush() - if ferr != nil { - rm.errs.Addf("flush error: %w", ferr) + if contentLength > 0 { + nn, werr := rm.w.Write(rm.Content()) + n += nn + if werr != nil { + rm.errs.Addf("write error: %w", werr) + } + // flush the response writer + if flusher, ok := rm.w.(http.Flusher); ok { + flusher.Flush() + } else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok { + ferr := errFlusher.Flush() + if ferr != nil { + rm.errs.Addf("flush error: %w", ferr) + } } } } // release the buffer and reset the pointers - pool.Put(rm.b) - rm.b = nil - rm.buf = nil + if rm.buf != nil { + rm.bufPool.PutBuffer(rm.buf) + rm.buf = nil + } // release the shared data if rm.shared != nil { diff --git a/internal/route/rules/template.go b/internal/route/rules/template.go index 538fb5a3..7415c863 100644 --- a/internal/route/rules/template.go +++ b/internal/route/rules/template.go @@ -8,6 +8,7 @@ import ( type templateOrStr interface { Execute(w io.Writer, data any) error + Len() int } type strTemplate string @@ -23,6 +24,10 @@ func (t strTemplate) Execute(w io.Writer, _ any) error { return nil } +func (t strTemplate) Len() int { + return len(t) +} + type keyValueTemplate = Tuple[string, templateOrStr] func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) { diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index 19fd3c88..bed8f935 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -304,6 +304,15 @@ func isTemplate(tmplStr string) bool { return strings.Contains(tmplStr, "{{") } +type templateWithLen struct { + *template.Template + len int +} + +func (t *templateWithLen) Len() int { + return t.len +} + func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) { if newline && !strings.HasSuffix(tmplStr, "\n") { tmplStr += "\n" @@ -317,7 +326,7 @@ func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) if err != nil { return nil, ErrInvalidArguments.With(err) } - return tmpl, nil + return &templateWithLen{tmpl, len(tmplStr)}, nil } func validateLevel(level string) (zerolog.Level, gperr.Error) {