refactor: refine byte pools usage and fix memory leak in rules

This commit is contained in:
yusing
2025-10-15 23:53:26 +08:00
parent 2b4c39a79e
commit 44536139c1
7 changed files with 73 additions and 42 deletions

View File

@@ -157,12 +157,12 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
}
line := lineBufPool.Get()
defer lineBufPool.Put(line)
line = l.AppendRequestLog(line, req, res)
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
l.write(line)
lineBufPool.Put(line)
}
func (l *AccessLogger) LogError(req *http.Request, err error) {
@@ -171,12 +171,12 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := lineBufPool.Get()
defer lineBufPool.Put(line)
line = l.AppendACLLog(line, info, blocked)
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
l.write(line)
lineBufPool.Put(line)
}
func (l *AccessLogger) ShouldRotate() bool {

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"io"
"slices"
"time"
"github.com/rs/zerolog"
@@ -167,14 +168,16 @@ func rotateLogFileByPolicy(file supportRotate, config *Retention, result *Rotate
// Read each line and write it to the beginning of the file
writePos := int64(0)
buf := rotateBytePool.Get()
defer rotateBytePool.Put(buf)
defer func() {
rotateBytePool.Put(buf)
}()
// in reverse order to keep the order of the lines (from old to new)
for i := len(linesToKeep) - 1; i >= 0; i-- {
line := linesToKeep[i]
n := line.Size
if cap(buf) < int(n) {
buf = make([]byte, n)
buf = slices.Grow(buf, int(n)-cap(buf))
}
buf = buf[:n]

View File

@@ -19,7 +19,6 @@ import (
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
"github.com/yusing/goutils/http/reverseproxy"
"github.com/yusing/goutils/synk"
)
type (
@@ -431,10 +430,7 @@ var commands = map[string]struct {
to := []string{provider}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
buf := bufPool.Get()
defer bufPool.Put(buf)
respBuf := bytes.NewBuffer(buf)
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
if err != nil {
@@ -446,10 +442,11 @@ var commands = map[string]struct {
return err
}
b := respBuf.Bytes()
notif.Notify(&notif.LogMessage{
Level: level,
Title: string(buf[:titleLen]),
Body: notif.MessageBodyBytes(buf[titleLen:]),
Title: string(b[:titleLen]),
Body: notif.MessageBodyBytes(b[titleLen:]),
To: to,
})
return nil
@@ -466,8 +463,6 @@ type reqResponseTemplateData struct {
}
}
var bufPool = synk.GetBytesPoolWithUniqueMemory()
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]

View File

@@ -1,7 +1,6 @@
package rules
import (
"bytes"
"io"
"net/http"
"net/url"
@@ -243,15 +242,14 @@ var modFields = map[string]struct {
r.Body = nil
}
buf := pool.Get()
b := bytes.NewBuffer(buf)
bufPool := GetInitResponseModifier(w).BufPool()
b := bufPool.GetBuffer()
err := executeRequestTemplateTo(tmpl, b, r)
if err != nil {
return err
}
r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() {
pool.Put(buf)
bufPool.PutBuffer(b)
})
return nil
}),

View File

@@ -13,8 +13,9 @@ import (
)
type ResponseModifier struct {
bufPool *synk.BytesPoolWithMemory
w http.ResponseWriter
b []byte // the bytes got from pool
buf *bytes.Buffer
statusCode int
shared Cache
@@ -29,8 +30,6 @@ type Response struct {
Header http.Header
}
var pool = synk.GetBytesPoolWithUniqueMemory()
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
for {
switch ww := w.(type) {
@@ -68,14 +67,16 @@ func GetSharedData(w http.ResponseWriter) Cache {
//
// It should only be called once, at the very beginning of the request.
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
b := pool.Get()
return &ResponseModifier{
w: w,
buf: bytes.NewBuffer(b),
b: b,
bufPool: synk.GetBytesPoolWithUniqueMemory(),
w: w,
}
}
func (rm *ResponseModifier) BufPool() *synk.BytesPoolWithMemory {
return rm.bufPool
}
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
// return rm.w
// }
@@ -85,13 +86,26 @@ func (rm *ResponseModifier) WriteHeader(code int) {
}
func (rm *ResponseModifier) ResetBody() {
if rm.buf == nil {
return
}
rm.buf.Reset()
}
func (rm *ResponseModifier) ContentLength() int {
if rm.buf == nil {
return 0
}
return rm.buf.Len()
}
func (rm *ResponseModifier) Content() []byte {
if rm.buf == nil {
return nil
}
return rm.buf.Bytes()
}
func (rm *ResponseModifier) StatusCode() int {
if rm.statusCode == 0 {
return http.StatusOK
@@ -108,6 +122,9 @@ func (rm *ResponseModifier) Response() Response {
}
func (rm *ResponseModifier) Write(b []byte) (int, error) {
if rm.buf == nil {
rm.buf = rm.bufPool.GetBuffer()
}
return rm.buf.Write(b)
}
@@ -139,29 +156,33 @@ func (rm *ResponseModifier) FlushRelease() (int, error) {
// h.Del(k)
// }
// }
h.Set("Content-Length", strconv.Itoa(rm.buf.Len()))
contentLength := rm.ContentLength()
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
rm.w.WriteHeader(rm.StatusCode())
nn, werr := rm.w.Write(rm.buf.Bytes())
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
// flush the response writer
if flusher, ok := rm.w.(http.Flusher); ok {
flusher.Flush()
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok {
ferr := errFlusher.Flush()
if ferr != nil {
rm.errs.Addf("flush error: %w", ferr)
if contentLength > 0 {
nn, werr := rm.w.Write(rm.Content())
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
// flush the response writer
if flusher, ok := rm.w.(http.Flusher); ok {
flusher.Flush()
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok {
ferr := errFlusher.Flush()
if ferr != nil {
rm.errs.Addf("flush error: %w", ferr)
}
}
}
}
// release the buffer and reset the pointers
pool.Put(rm.b)
rm.b = nil
rm.buf = nil
if rm.buf != nil {
rm.bufPool.PutBuffer(rm.buf)
rm.buf = nil
}
// release the shared data
if rm.shared != nil {

View File

@@ -8,6 +8,7 @@ import (
type templateOrStr interface {
Execute(w io.Writer, data any) error
Len() int
}
type strTemplate string
@@ -23,6 +24,10 @@ func (t strTemplate) Execute(w io.Writer, _ any) error {
return nil
}
func (t strTemplate) Len() int {
return len(t)
}
type keyValueTemplate = Tuple[string, templateOrStr]
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {

View File

@@ -304,6 +304,15 @@ func isTemplate(tmplStr string) bool {
return strings.Contains(tmplStr, "{{")
}
type templateWithLen struct {
*template.Template
len int
}
func (t *templateWithLen) Len() int {
return t.len
}
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
if newline && !strings.HasSuffix(tmplStr, "\n") {
tmplStr += "\n"
@@ -317,7 +326,7 @@ func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return tmpl, nil
return &templateWithLen{tmpl, len(tmplStr)}, nil
}
func validateLevel(level string) (zerolog.Level, gperr.Error) {