refactor: refine byte pools usage and fix memory leak in rules

This commit is contained in:
yusing
2025-10-15 23:53:26 +08:00
parent 2b4c39a79e
commit 44536139c1
7 changed files with 73 additions and 42 deletions

View File

@@ -157,12 +157,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
} }
line := lineBufPool.Get() line := lineBufPool.Get()
defer lineBufPool.Put(line)
line = l.AppendRequestLog(line, req, res) line = l.AppendRequestLog(line, req, res)
if line[len(line)-1] != '\n' { if line[len(line)-1] != '\n' {
line = append(line, '\n') line = append(line, '\n')
} }
l.write(line) l.write(line)
lineBufPool.Put(line)
} }
func (l *AccessLogger) LogError(req *http.Request, err error) { func (l *AccessLogger) LogError(req *http.Request, err error) {
@@ -171,12 +171,12 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := lineBufPool.Get() line := lineBufPool.Get()
defer lineBufPool.Put(line)
line = l.AppendACLLog(line, info, blocked) line = l.AppendACLLog(line, info, blocked)
if line[len(line)-1] != '\n' { if line[len(line)-1] != '\n' {
line = append(line, '\n') line = append(line, '\n')
} }
l.write(line) l.write(line)
lineBufPool.Put(line)
} }
func (l *AccessLogger) ShouldRotate() bool { func (l *AccessLogger) ShouldRotate() bool {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"slices"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -167,14 +168,16 @@ func rotateLogFileByPolicy(file supportRotate, config *Retention, result *Rotate
// Read each line and write it to the beginning of the file // Read each line and write it to the beginning of the file
writePos := int64(0) writePos := int64(0)
buf := rotateBytePool.Get() buf := rotateBytePool.Get()
defer rotateBytePool.Put(buf) defer func() {
rotateBytePool.Put(buf)
}()
// in reverse order to keep the order of the lines (from old to new) // in reverse order to keep the order of the lines (from old to new)
for i := len(linesToKeep) - 1; i >= 0; i-- { for i := len(linesToKeep) - 1; i >= 0; i-- {
line := linesToKeep[i] line := linesToKeep[i]
n := line.Size n := line.Size
if cap(buf) < int(n) { if cap(buf) < int(n) {
buf = make([]byte, n) buf = slices.Grow(buf, int(n)-cap(buf))
} }
buf = buf[:n] buf = buf[:n]

View File

@@ -19,7 +19,6 @@ import (
gperr "github.com/yusing/goutils/errs" gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http" httputils "github.com/yusing/goutils/http"
"github.com/yusing/goutils/http/reverseproxy" "github.com/yusing/goutils/http/reverseproxy"
"github.com/yusing/goutils/synk"
) )
type ( type (
@@ -431,10 +430,7 @@ var commands = map[string]struct {
to := []string{provider} to := []string{provider}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
buf := bufPool.Get() respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
defer bufPool.Put(buf)
respBuf := bytes.NewBuffer(buf)
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r) err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
if err != nil { if err != nil {
@@ -446,10 +442,11 @@ var commands = map[string]struct {
return err return err
} }
b := respBuf.Bytes()
notif.Notify(&notif.LogMessage{ notif.Notify(&notif.LogMessage{
Level: level, Level: level,
Title: string(buf[:titleLen]), Title: string(b[:titleLen]),
Body: notif.MessageBodyBytes(buf[titleLen:]), Body: notif.MessageBodyBytes(b[titleLen:]),
To: to, To: to,
}) })
return nil return nil
@@ -466,8 +463,6 @@ type reqResponseTemplateData struct {
} }
} }
var bufPool = synk.GetBytesPoolWithUniqueMemory()
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr] type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr] type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]

View File

@@ -1,7 +1,6 @@
package rules package rules
import ( import (
"bytes"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@@ -243,15 +242,14 @@ var modFields = map[string]struct {
r.Body = nil r.Body = nil
} }
buf := pool.Get() bufPool := GetInitResponseModifier(w).BufPool()
b := bytes.NewBuffer(buf) b := bufPool.GetBuffer()
err := executeRequestTemplateTo(tmpl, b, r) err := executeRequestTemplateTo(tmpl, b, r)
if err != nil { if err != nil {
return err return err
} }
r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() { r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() {
pool.Put(buf) bufPool.PutBuffer(b)
}) })
return nil return nil
}), }),

View File

@@ -13,8 +13,9 @@ import (
) )
type ResponseModifier struct { type ResponseModifier struct {
bufPool *synk.BytesPoolWithMemory
w http.ResponseWriter w http.ResponseWriter
b []byte // the bytes got from pool
buf *bytes.Buffer buf *bytes.Buffer
statusCode int statusCode int
shared Cache shared Cache
@@ -29,8 +30,6 @@ type Response struct {
Header http.Header Header http.Header
} }
var pool = synk.GetBytesPoolWithUniqueMemory()
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier { func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
for { for {
switch ww := w.(type) { switch ww := w.(type) {
@@ -68,14 +67,16 @@ func GetSharedData(w http.ResponseWriter) Cache {
// //
// It should only be called once, at the very beginning of the request. // It should only be called once, at the very beginning of the request.
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier { func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
b := pool.Get()
return &ResponseModifier{ return &ResponseModifier{
w: w, bufPool: synk.GetBytesPoolWithUniqueMemory(),
buf: bytes.NewBuffer(b), w: w,
b: b,
} }
} }
func (rm *ResponseModifier) BufPool() *synk.BytesPoolWithMemory {
return rm.bufPool
}
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter { // func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
// return rm.w // return rm.w
// } // }
@@ -85,13 +86,26 @@ func (rm *ResponseModifier) WriteHeader(code int) {
} }
func (rm *ResponseModifier) ResetBody() { func (rm *ResponseModifier) ResetBody() {
if rm.buf == nil {
return
}
rm.buf.Reset() rm.buf.Reset()
} }
func (rm *ResponseModifier) ContentLength() int { func (rm *ResponseModifier) ContentLength() int {
if rm.buf == nil {
return 0
}
return rm.buf.Len() return rm.buf.Len()
} }
func (rm *ResponseModifier) Content() []byte {
if rm.buf == nil {
return nil
}
return rm.buf.Bytes()
}
func (rm *ResponseModifier) StatusCode() int { func (rm *ResponseModifier) StatusCode() int {
if rm.statusCode == 0 { if rm.statusCode == 0 {
return http.StatusOK return http.StatusOK
@@ -108,6 +122,9 @@ func (rm *ResponseModifier) Response() Response {
} }
func (rm *ResponseModifier) Write(b []byte) (int, error) { func (rm *ResponseModifier) Write(b []byte) (int, error) {
if rm.buf == nil {
rm.buf = rm.bufPool.GetBuffer()
}
return rm.buf.Write(b) return rm.buf.Write(b)
} }
@@ -139,29 +156,33 @@ func (rm *ResponseModifier) FlushRelease() (int, error) {
// h.Del(k) // h.Del(k)
// } // }
// } // }
h.Set("Content-Length", strconv.Itoa(rm.buf.Len())) contentLength := rm.ContentLength()
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
rm.w.WriteHeader(rm.StatusCode()) rm.w.WriteHeader(rm.StatusCode())
nn, werr := rm.w.Write(rm.buf.Bytes())
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
// flush the response writer if contentLength > 0 {
if flusher, ok := rm.w.(http.Flusher); ok { nn, werr := rm.w.Write(rm.Content())
flusher.Flush() n += nn
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok { if werr != nil {
ferr := errFlusher.Flush() rm.errs.Addf("write error: %w", werr)
if ferr != nil { }
rm.errs.Addf("flush error: %w", ferr) // flush the response writer
if flusher, ok := rm.w.(http.Flusher); ok {
flusher.Flush()
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok {
ferr := errFlusher.Flush()
if ferr != nil {
rm.errs.Addf("flush error: %w", ferr)
}
} }
} }
} }
// release the buffer and reset the pointers // release the buffer and reset the pointers
pool.Put(rm.b) if rm.buf != nil {
rm.b = nil rm.bufPool.PutBuffer(rm.buf)
rm.buf = nil rm.buf = nil
}
// release the shared data // release the shared data
if rm.shared != nil { if rm.shared != nil {

View File

@@ -8,6 +8,7 @@ import (
type templateOrStr interface { type templateOrStr interface {
Execute(w io.Writer, data any) error Execute(w io.Writer, data any) error
Len() int
} }
type strTemplate string type strTemplate string
@@ -23,6 +24,10 @@ func (t strTemplate) Execute(w io.Writer, _ any) error {
return nil return nil
} }
func (t strTemplate) Len() int {
return len(t)
}
type keyValueTemplate = Tuple[string, templateOrStr] type keyValueTemplate = Tuple[string, templateOrStr]
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) { func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {

View File

@@ -304,6 +304,15 @@ func isTemplate(tmplStr string) bool {
return strings.Contains(tmplStr, "{{") return strings.Contains(tmplStr, "{{")
} }
type templateWithLen struct {
*template.Template
len int
}
func (t *templateWithLen) Len() int {
return t.len
}
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) { func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
if newline && !strings.HasSuffix(tmplStr, "\n") { if newline && !strings.HasSuffix(tmplStr, "\n") {
tmplStr += "\n" tmplStr += "\n"
@@ -317,7 +326,7 @@ func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error)
if err != nil { if err != nil {
return nil, ErrInvalidArguments.With(err) return nil, ErrInvalidArguments.With(err)
} }
return tmpl, nil return &templateWithLen{tmpl, len(tmplStr)}, nil
} }
func validateLevel(level string) (zerolog.Level, gperr.Error) { func validateLevel(level string) (zerolog.Level, gperr.Error) {