fix(access_logger): nil panic when stdout only, improve concurrency safety

This commit is contained in:
yusing
2025-11-01 01:17:55 +08:00
parent 77e486f4fe
commit e670acb4b8
11 changed files with 207 additions and 117 deletions

Submodule goutils updated: 84457ea2e1...8f224d7c42

View File

@@ -7,6 +7,7 @@ import (
"sync/atomic"
"time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
maxmind "github.com/yusing/godoxy/internal/maxmind/types"
@@ -23,11 +24,9 @@ type (
task *task.Task
cfg *Config
rawWriter io.Writer
closer io.Closer
supportRotate supportRotate
writer *ioutils.BufferedWriter
writeLock sync.Mutex
writer BufferedWriter
supportRotate SupportRotate
writeLock *sync.Mutex
closed bool
writeCount int64
@@ -41,8 +40,9 @@ type (
ACLFormatter
}
WriterWithName interface {
Writer interface {
io.WriteCloser
ShouldBeBuffered() bool
Name() string // file name or path
}
@@ -62,6 +62,8 @@ type (
}
)
var writerLocks = xsync.NewMap[string, *sync.Mutex]()
const (
InitialBufferSize = 4 * kilobyte
MaxBufferSize = 8 * megabyte
@@ -87,10 +89,10 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
}
func NewMockAccessLogger(parent task.Parent, cfg *RequestLoggerConfig) *AccessLogger {
return NewAccessLoggerWithIO(parent, NewMockFile(), cfg)
return NewAccessLoggerWithIO(parent, NewMockFile(true), cfg)
}
func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg AnyConfig) *AccessLogger {
func NewAccessLoggerWithIO(parent task.Parent, writer Writer, anyCfg AnyConfig) *AccessLogger {
cfg := anyCfg.ToConfig()
if cfg.RotateInterval == 0 {
cfg.RotateInterval = defaultRotateInterval
@@ -99,20 +101,21 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
l := &AccessLogger{
task: parent.Subtask("accesslog."+writer.Name(), true),
cfg: cfg,
rawWriter: writer,
bufSize: InitialBufferSize,
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
logger: log.With().Str("file", writer.Name()).Logger(),
}
if writer != nil {
l.writeLock, _ = writerLocks.LoadOrStore(writer.Name(), &sync.Mutex{})
if writer.ShouldBeBuffered() {
l.writer = ioutils.NewBufferedWriter(writer, InitialBufferSize)
if supportRotate, ok := writer.(SupportRotate); ok {
l.supportRotate = supportRotate
}
if closer, ok := writer.(io.Closer); ok {
l.closer = closer
}
} else {
l.writer = NewUnbufferedWriter(writer)
}
if supportRotate, ok := writer.(SupportRotate); ok {
l.supportRotate = supportRotate
}
if cfg.req != nil {
@@ -131,9 +134,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
l.ACLFormatter = ACLLogFormatter{}
}
if l.writer != nil {
go l.start()
} // otherwise stdout only
go l.start()
return l
}
@@ -188,7 +189,7 @@ func (l *AccessLogger) Rotate(result *RotateResult) (rotated bool, err error) {
return false, nil
}
l.writer.Flush()
l.Flush()
l.writeLock.Lock()
defer l.writeLock.Unlock()
@@ -247,12 +248,9 @@ func (l *AccessLogger) Close() error {
if l.closed {
return nil
}
if l.closer != nil {
l.closer.Close()
}
l.writer.Release()
l.writer.Flush()
l.closed = true
return nil
return l.writer.Close()
}
func (l *AccessLogger) Flush() {
@@ -261,29 +259,22 @@ func (l *AccessLogger) Flush() {
if l.closed {
return
}
if err := l.writer.Flush(); err != nil {
l.handleErr(err)
}
l.writer.Flush()
}
func (l *AccessLogger) write(data []byte) {
if l.writer != nil {
l.writeLock.Lock()
defer l.writeLock.Unlock()
if l.closed {
return
}
n, err := l.writer.Write(data)
if err != nil {
l.handleErr(err)
} else if n < len(data) {
l.handleErr(gperr.Errorf("%w, writing %d bytes, only %d written", io.ErrShortWrite, len(data), n))
}
atomic.AddInt64(&l.writeCount, int64(n))
l.writeLock.Lock()
defer l.writeLock.Unlock()
if l.closed {
return
}
if l.cfg.Stdout {
log.Logger.Write(data) // write to stdout immediately
n, err := l.writer.Write(data)
if err != nil {
l.handleErr(err)
} else if n < len(data) {
l.handleErr(gperr.Errorf("%w, writing %d bytes, only %d written", io.ErrShortWrite, len(data), n))
}
atomic.AddInt64(&l.writeCount, int64(n))
}
func (l *AccessLogger) adjustBuffer() {

View File

@@ -61,7 +61,7 @@ func TestBackScanner(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup mock file
mockFile := NewMockFile()
mockFile := NewMockFile(false)
_, err := mockFile.Write([]byte(tt.input))
if err != nil {
t.Fatalf("failed to write to mock file: %v", err)
@@ -103,7 +103,7 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) {
for _, chunkSize := range chunkSizes {
t.Run(fmt.Sprintf("chunk_size_%d", chunkSize), func(t *testing.T) {
mockFile := NewMockFile()
mockFile := NewMockFile(false)
_, err := mockFile.Write([]byte(input))
if err != nil {
t.Fatalf("failed to write to mock file: %v", err)
@@ -197,7 +197,7 @@ func TestReset(t *testing.T) {
// 100000 log entries.
func BenchmarkBackScanner(b *testing.B) {
mockFile := NewMockFile()
mockFile := NewMockFile(false)
line := logEntry()
for range 100000 {
_, _ = mockFile.Write(line)

View File

@@ -32,7 +32,7 @@ type (
}
AnyConfig interface {
ToConfig() *Config
IO() (WriterWithName, error)
IO() (Writer, error)
}
Format string
@@ -66,8 +66,7 @@ func (cfg *ConfigBase) Validate() gperr.Error {
}
// IO returns a writer for the config.
// If only stdout is enabled, it returns nil, nil.
func (cfg *ConfigBase) IO() (WriterWithName, error) {
func (cfg *ConfigBase) IO() (Writer, error) {
if cfg.Path != "" {
io, err := NewFileIO(cfg.Path)
if err != nil {
@@ -75,7 +74,7 @@ func (cfg *ConfigBase) IO() (WriterWithName, error) {
}
return io, nil
}
return nil, nil
return NewStdout(), nil
}
func (cfg *ACLLoggerConfig) ToConfig() *Config {

View File

@@ -29,12 +29,19 @@ var (
// NewFileIO creates a new file writer with cleaned path.
//
// If the file is already opened, it will be returned.
func NewFileIO(path string) (WriterWithName, error) {
func NewFileIO(path string) (Writer, error) {
openedFilesMu.Lock()
defer openedFilesMu.Unlock()
var file *File
path = filepath.Clean(path)
var err error
// make it absolute path, so that we can use it as key of `openedFiles` and shared lock
path, err = filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("access log path error: %w", err)
}
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
return opened, nil
@@ -54,8 +61,13 @@ func NewFileIO(path string) (WriterWithName, error) {
return file, nil
}
// Name returns the absolute path of the file.
func (f *File) Name() string {
return f.f.Name()
return f.path
}
func (f *File) ShouldBeBuffered() bool {
return true
}
func (f *File) Write(p []byte) (n int, err error) {

View File

@@ -1,89 +1,96 @@
package accesslog
import (
"fmt"
"math/rand/v2"
"net/http"
"os"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yusing/goutils/task"
expect "github.com/yusing/goutils/testing"
)
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
var wg sync.WaitGroup
cfg := DefaultRequestLoggerConfig()
cfg.Path = "test.log"
loggerCount := 10
accessLogIOs := make([]WriterWithName, loggerCount)
loggerCount := runtime.GOMAXPROCS(0)
accessLogIOs := make([]Writer, loggerCount)
// make test log file
file, err := os.Create(cfg.Path)
expect.NoError(t, err)
assert.NoError(t, err)
file.Close()
t.Cleanup(func() {
expect.NoError(t, os.Remove(cfg.Path))
assert.NoError(t, os.Remove(cfg.Path))
})
var wg sync.WaitGroup
for i := range loggerCount {
wg.Add(1)
go func(index int) {
defer wg.Done()
wg.Go(func() {
file, err := NewFileIO(cfg.Path)
expect.NoError(t, err)
accessLogIOs[index] = file
}(i)
assert.NoError(t, err)
accessLogIOs[i] = file
})
}
wg.Wait()
firstIO := accessLogIOs[0]
for _, io := range accessLogIOs {
expect.Equal(t, io, firstIO)
assert.Equal(t, firstIO, io)
}
}
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
file := NewMockFile()
for _, buffered := range []bool{false, true} {
t.Run(fmt.Sprintf("buffered=%t", buffered), func(t *testing.T) {
file := NewMockFile(buffered)
cfg := DefaultRequestLoggerConfig()
parent := task.RootTask("test", false)
cfg := DefaultRequestLoggerConfig()
parent := task.RootTask("test", false)
loggerCount := 5
logCountPerLogger := 10
loggers := make([]*AccessLogger, loggerCount)
loggerCount := runtime.GOMAXPROCS(0)
logCountPerLogger := 10
loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount {
loggers[i] = NewAccessLoggerWithIO(parent, file, cfg)
for i := range loggerCount {
loggers[i] = NewAccessLoggerWithIO(parent, file, cfg)
}
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp := &http.Response{StatusCode: http.StatusOK}
var wg sync.WaitGroup
for _, logger := range loggers {
wg.Go(func() {
concurrentLog(logger, req, resp, logCountPerLogger)
})
}
wg.Wait()
for _, logger := range loggers {
logger.Close()
}
expected := loggerCount * logCountPerLogger
actual := file.NumLines()
assert.Equal(t, expected, actual)
})
}
var wg sync.WaitGroup
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp := &http.Response{StatusCode: http.StatusOK}
wg.Add(len(loggers))
for _, logger := range loggers {
go func(l *AccessLogger) {
defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger)
l.Flush()
}(logger)
}
wg.Wait()
expected := loggerCount * logCountPerLogger
actual := file.NumLines()
expect.Equal(t, actual, expected)
}
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {
func concurrentLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {
var wg sync.WaitGroup
for range n {
wg.Go(func() {
logger.Log(req, resp)
if rand.IntN(2) == 0 {
logger.Flush()
}
})
}
wg.Wait()

View File

@@ -7,26 +7,27 @@ import (
"github.com/spf13/afero"
)
type noLock struct{}
func (noLock) Lock() {}
func (noLock) Unlock() {}
type MockFile struct {
afero.File
noLock
buffered bool
}
var _ SupportRotate = (*MockFile)(nil)
func NewMockFile() *MockFile {
func NewMockFile(buffered bool) *MockFile {
f, _ := afero.TempFile(afero.NewMemMapFs(), "", "")
f.Seek(0, io.SeekEnd)
return &MockFile{
File: f,
File: f,
buffered: buffered,
}
}
func (m *MockFile) ShouldBeBuffered() bool {
return m.buffered
}
func (m *MockFile) Len() int64 {
filesize, _ := m.Seek(0, io.SeekEnd)
_, _ = m.Seek(0, io.SeekStart)
@@ -60,3 +61,7 @@ func (m *MockFile) MustSize() int64 {
size, _ := m.Size()
return size
}
func (m *MockFile) Close() error {
return nil
}

View File

@@ -55,7 +55,7 @@ func TestParseLogTime(t *testing.T) {
func TestRotateKeepLast(t *testing.T) {
for _, format := range ReqLoggerFormats {
t.Run(string(format)+" keep last", func(t *testing.T) {
file := NewMockFile()
file := NewMockFile(true)
utils.MockTimeNow(testTime)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
Format: format,
@@ -86,7 +86,7 @@ func TestRotateKeepLast(t *testing.T) {
})
t.Run(string(format)+" keep days", func(t *testing.T) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
Format: format,
})
@@ -132,7 +132,7 @@ func TestRotateKeepLast(t *testing.T) {
func TestRotateKeepFileSize(t *testing.T) {
for _, format := range ReqLoggerFormats {
t.Run(string(format)+" keep size no rotation", func(t *testing.T) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
Format: format,
})
@@ -164,7 +164,7 @@ func TestRotateKeepFileSize(t *testing.T) {
}
t.Run("keep size with rotation", func(t *testing.T) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
Format: FormatJSON,
})
@@ -198,7 +198,7 @@ func TestRotateKeepFileSize(t *testing.T) {
func TestRotateSkipInvalidTime(t *testing.T) {
for _, format := range ReqLoggerFormats {
t.Run(string(format), func(t *testing.T) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
Format: format,
})
@@ -240,7 +240,7 @@ func BenchmarkRotate(b *testing.B) {
}
for _, retention := range tests {
b.Run(fmt.Sprintf("retention_%s", retention.String()), func(b *testing.B) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
ConfigBase: ConfigBase{
Retention: retention,
@@ -256,7 +256,7 @@ func BenchmarkRotate(b *testing.B) {
b.ResetTimer()
for b.Loop() {
b.StopTimer()
file = NewMockFile()
file = NewMockFile(true)
_, _ = file.Write(content)
b.StartTimer()
var result RotateResult
@@ -274,7 +274,7 @@ func BenchmarkRotateWithInvalidTime(b *testing.B) {
}
for _, retention := range tests {
b.Run(fmt.Sprintf("retention_%s", retention.String()), func(b *testing.B) {
file := NewMockFile()
file := NewMockFile(true)
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
ConfigBase: ConfigBase{
Retention: retention,
@@ -293,7 +293,7 @@ func BenchmarkRotateWithInvalidTime(b *testing.B) {
b.ResetTimer()
for b.Loop() {
b.StopTimer()
file = NewMockFile()
file = NewMockFile(true)
_, _ = file.Write(content)
b.StartTimer()
var result RotateResult

View File

@@ -0,0 +1,32 @@
package accesslog
import (
"os"
"github.com/rs/zerolog"
"github.com/yusing/godoxy/internal/logging"
)
type Stdout struct {
logger zerolog.Logger
}
func NewStdout() Writer {
return &Stdout{logger: logging.NewLoggerWithFixedLevel(zerolog.InfoLevel, os.Stdout)}
}
func (s Stdout) Name() string {
return "stdout"
}
func (s Stdout) ShouldBeBuffered() bool {
return false
}
func (s Stdout) Write(p []byte) (n int, err error) {
return s.logger.Write(p)
}
func (s Stdout) Close() error {
return nil
}

View File

@@ -0,0 +1,47 @@
package accesslog
import (
"io"
)
type BufferedWriter interface {
io.Writer
io.Closer
Flush() error
Resize(size int) error
}
type unbufferedWriter struct {
w io.Writer
}
func NewUnbufferedWriter(w io.Writer) BufferedWriter {
return unbufferedWriter{w: w}
}
func (w unbufferedWriter) Write(p []byte) (n int, err error) {
return w.w.Write(p)
}
func (w unbufferedWriter) Close() error {
if closer, ok := w.w.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (w unbufferedWriter) Flush() error {
if flusher, ok := w.w.(interface{ Flush() }); ok {
flusher.Flush()
} else if errFlusher, ok := w.w.(interface{ FlushError() error }); ok {
return errFlusher.FlushError()
} else if errFlusher2, ok := w.w.(interface{ Flush() error }); ok {
return errFlusher2.Flush()
}
return nil
}
func (w unbufferedWriter) Resize(size int) error {
// No-op for unbuffered writer
return nil
}

View File

@@ -52,10 +52,7 @@ func ValidateVars(s string) error {
func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
dst := ioutils.NewBufferedWriter(dstW, 1024)
defer func() {
dst.Flush()
dst.Release()
}()
defer dst.Close()
for i := 0; i < len(src); i++ {
ch := src[i]