diff --git a/go.mod b/go.mod index 1f7d20e2..dd6ab467 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,9 @@ require ( github.com/bytedance/sonic v1.13.2 github.com/docker/cli v28.1.1+incompatible github.com/luthermonson/go-proxmox v0.2.2 + github.com/spf13/afero v1.14.0 github.com/stretchr/testify v1.10.0 + go.uber.org/atomic v1.11.0 ) replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e @@ -49,7 +51,7 @@ require ( github.com/cloudflare/cloudflare-go v0.115.0 // indirect github.com/cloudwego/base64x v0.1.5 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/diskfs/go-diskfs v1.5.0 // indirect + github.com/diskfs/go-diskfs v1.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/djherbis/times v1.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -64,11 +66,11 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect - github.com/jinzhu/copier v0.3.4 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/jinzhu/copier v0.4.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect - github.com/magefile/mage v1.14.0 // indirect + github.com/magefile/mage v1.15.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.65 // indirect @@ -93,7 +95,7 @@ require ( go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - golang.org/x/arch v0.8.0 // indirect + golang.org/x/arch v0.16.0 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect diff --git a/go.sum b/go.sum index de4e8928..3eab3a7f 100644 --- a/go.sum +++ b/go.sum @@ -33,8 +33,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/diskfs/go-diskfs v1.5.0 h1:0SANkrab4ifiZBytk380gIesYh5Gc+3i40l7qsrYP4s= -github.com/diskfs/go-diskfs v1.5.0/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk= +github.com/diskfs/go-diskfs v1.6.0 h1:YmK5+vLSfkwC6kKKRTRPGaDGNF+Xh8FXeiNHwryDfu4= +github.com/diskfs/go-diskfs v1.6.0/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= @@ -107,15 +107,15 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslC github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc= github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= -github.com/jinzhu/copier v0.3.4 h1:mfU6jI9PtCeUjkjQ322dlff9ELjGDu975C2p/nrubVI= -github.com/jinzhu/copier v0.3.4/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -131,8 +131,8 @@ github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr32 github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/luthermonson/go-proxmox v0.2.2 h1:BZ7VEj302wxw2i/EwTcyEiBzQib8teocB2SSkLHyySY= github.com/luthermonson/go-proxmox v0.2.2/go.mod h1:oyFgg2WwTEIF0rP6ppjiixOHa5ebK1p8OaRiFhvICBQ= -github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= -github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= +github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -194,6 +194,8 @@ github.com/shirou/gopsutil/v4 v4.25.3 h1:SeA68lsu8gLggyMbmCn8cmp97V1TI9ld9sVzAUc github.com/shirou/gopsutil/v4 v4.25.3/go.mod h1:xbuxyoZj+UsgnZrENu3lQivsngRR5BdjbJwf2fv4szA= github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af h1:Sp5TG9f7K39yfB+If0vjp97vuT74F72r8hfRpP8jLU0= github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= +github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -234,8 +236,10 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= -golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= -golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/arch v0.16.0 h1:foMtLTdyOmIniqWCHjY6+JxuC54XP1fDwx4N0ASyW+U= +golang.org/x/arch v0.16.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 17dcf271..df02b4df 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -60,7 +60,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) return } - ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) + ep.accessLogger, err = accesslog.NewAccessLogger(parent, cfg) if err != nil { return } diff --git a/internal/net/gphttp/accesslog/access_logger.go b/internal/net/gphttp/accesslog/access_logger.go index 9a6fc8a1..bf45847b 100644 --- a/internal/net/gphttp/accesslog/access_logger.go +++ b/internal/net/gphttp/accesslog/access_logger.go @@ -2,59 +2,99 @@ package accesslog import ( "bufio" - "bytes" "io" "net/http" "sync" "time" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils/synk" + "golang.org/x/time/rate" ) type ( AccessLogger struct { - task *task.Task - cfg *Config - io AccessLogIO - buffered *bufio.Writer + task *task.Task + cfg *Config + io AccessLogIO + buffered *bufio.Writer + supportRotate bool + + lineBufPool *synk.BytesPool // buffer pool for formatting a single log line + + errRateLimiter *rate.Limiter + + logger zerolog.Logger - lineBufPool sync.Pool // buffer pool for formatting a single log line Formatter } AccessLogIO interface { - io.ReadWriteCloser - io.ReadWriteSeeker - io.ReaderAt + io.Writer sync.Locker Name() string // file name or path - Truncate(size int64) error } Formatter interface { - // Format writes a log line to line without a trailing newline - Format(line *bytes.Buffer, req *http.Request, res *http.Response) - SetGetTimeNow(getTimeNow func() time.Time) + // AppendLog appends a log line to line with or without a trailing newline + AppendLog(line []byte, req *http.Request, res *http.Response) []byte } ) -func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { +const MinBufferSize = 4 * kilobyte + +const ( + flushInterval = 30 * time.Second + rotateInterval = time.Hour +) + +func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { + var ios []AccessLogIO + + if cfg.Stdout { + ios = append(ios, stdoutIO) + } + + if cfg.Path != "" { + io, err := newFileIO(cfg.Path) + if err != nil { + return nil, err + } + ios = append(ios, io) + } + + if len(ios) == 0 { + return nil, nil + } + + return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil +} + +func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger { + return NewAccessLoggerWithIO(parent, NewMockFile(), cfg) +} + +func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { if cfg.BufferSize == 0 { cfg.BufferSize = DefaultBufferSize } - if cfg.BufferSize < 4096 { - cfg.BufferSize = 4096 + if cfg.BufferSize < MinBufferSize { + cfg.BufferSize = MinBufferSize } l := &AccessLogger{ - task: parent.Subtask("accesslog"), - cfg: cfg, - io: io, - buffered: bufio.NewWriterSize(io, cfg.BufferSize), + task: parent.Subtask("accesslog."+io.Name(), true), + cfg: cfg, + io: io, + buffered: bufio.NewWriterSize(io, cfg.BufferSize), + lineBufPool: synk.NewBytesPool(1024, synk.DefaultMaxBytes), + errRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1), + logger: logging.With().Str("file", io.Name()).Logger(), } - fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now} + fmt := CommonFormatter{cfg: &l.cfg.Fields} switch l.cfg.Format { case FormatCommon: l.Formatter = &fmt @@ -66,14 +106,19 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog panic("invalid access log format") } - l.lineBufPool.New = func() any { - return bytes.NewBuffer(make([]byte, 0, 1024)) + if _, ok := l.io.(supportRotate); ok { + l.supportRotate = true } + go l.start() return l } -func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool { +func (l *AccessLogger) Config() *Config { + return l.cfg +} + +func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool { if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) || !l.cfg.Filters.Method.CheckKeep(req, res) || !l.cfg.Filters.Headers.CheckKeep(req, res) || @@ -84,53 +129,63 @@ func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool { } func (l *AccessLogger) Log(req *http.Request, res *http.Response) { - if !l.checkKeep(req, res) { + if !l.shouldLog(req, res) { return } - line := l.lineBufPool.Get().(*bytes.Buffer) - line.Reset() + line := l.lineBufPool.Get() defer l.lineBufPool.Put(line) - l.Formatter.Format(line, req, res) - line.WriteRune('\n') - l.write(line.Bytes()) + line = l.Formatter.AppendLog(line, req, res) + if line[len(line)-1] != '\n' { + line = append(line, '\n') + } + l.lockWrite(line) } func (l *AccessLogger) LogError(req *http.Request, err error) { l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) } -func (l *AccessLogger) Config() *Config { - return l.cfg +func (l *AccessLogger) ShouldRotate() bool { + return l.cfg.Retention.IsValid() && l.supportRotate } -func (l *AccessLogger) Rotate() error { - if l.cfg.Retention == nil { - return nil +func (l *AccessLogger) Rotate() (result *RotateResult, err error) { + if !l.ShouldRotate() { + return nil, nil } + l.io.Lock() defer l.io.Unlock() - return l.rotate() + return rotateLogFile(l.io.(supportRotate), l.cfg.Retention) } func (l *AccessLogger) handleErr(err error) { - gperr.LogError("failed to write access log", err) + if l.errRateLimiter.Allow() { + gperr.LogError("failed to write access log", err) + } else { + gperr.LogError("too many errors, stopping access log", err) + l.task.Finish(err) + } } func (l *AccessLogger) start() { defer func() { + defer l.task.Finish(nil) + defer l.close() if err := l.Flush(); err != nil { l.handleErr(err) } - l.close() - l.task.Finish(nil) }() // flushes the buffer every 30 seconds flushTicker := time.NewTicker(30 * time.Second) defer flushTicker.Stop() + rotateTicker := time.NewTicker(rotateInterval) + defer rotateTicker.Stop() + for { select { case <-l.task.Context().Done(): @@ -139,6 +194,18 @@ func (l *AccessLogger) start() { if err := l.Flush(); err != nil { l.handleErr(err) } + case <-rotateTicker.C: + if !l.ShouldRotate() { + continue + } + l.logger.Info().Msg("rotating access log file") + if res, err := l.Rotate(); err != nil { + l.handleErr(err) + } else if res != nil { + res.Print(&l.logger) + } else { + l.logger.Info().Msg("no rotation needed") + } } } } @@ -150,18 +217,20 @@ func (l *AccessLogger) Flush() error { } func (l *AccessLogger) close() { - l.io.Lock() - defer l.io.Unlock() - l.io.Close() + if r, ok := l.io.(io.Closer); ok { + l.io.Lock() + defer l.io.Unlock() + r.Close() + } } -func (l *AccessLogger) write(data []byte) { +func (l *AccessLogger) lockWrite(data []byte) { l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers _, err := l.buffered.Write(data) l.io.Unlock() if err != nil { l.handleErr(err) } else { - logging.Debug().Msg("access log flushed to " + l.io.Name()) + logging.Trace().Msg("access log flushed to " + l.io.Name()) } } diff --git a/internal/net/gphttp/accesslog/access_logger_test.go b/internal/net/gphttp/accesslog/access_logger_test.go index 012d8ebc..cb8f468a 100644 --- a/internal/net/gphttp/accesslog/access_logger_test.go +++ b/internal/net/gphttp/accesslog/access_logger_test.go @@ -1,7 +1,6 @@ package accesslog_test import ( - "bytes" "encoding/json" "fmt" "net/http" @@ -11,7 +10,7 @@ import ( . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/task" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) const ( @@ -22,14 +21,14 @@ const ( referer = "https://www.google.com/" proto = "HTTP/1.1" ua = "Go-http-client/1.1" - status = http.StatusOK + status = http.StatusNotFound contentLength = 100 method = http.MethodGet ) var ( testTask = task.RootTask("test", false) - testURL = Must(url.Parse("http://" + host + uri)) + testURL = expect.Must(url.Parse("http://" + host + uri)) req = &http.Request{ RemoteAddr: remote, Method: method, @@ -53,22 +52,20 @@ var ( ) func fmtLog(cfg *Config) (ts string, line string) { - var buf bytes.Buffer + buf := make([]byte, 0, 1024) t := time.Now() - logger := NewAccessLogger(testTask, nil, cfg) - logger.Formatter.SetGetTimeNow(func() time.Time { - return t - }) - logger.Format(&buf, req, resp) - return t.Format(LogTimeFormat), buf.String() + logger := NewMockAccessLogger(testTask, cfg) + MockTimeNow(t) + buf = logger.AppendLog(buf, req, resp) + return t.Format(LogTimeFormat), string(buf) } func TestAccessLoggerCommon(t *testing.T) { config := DefaultConfig() config.Format = FormatCommon ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d", host, remote, ts, method, uri, proto, status, contentLength, ), @@ -79,7 +76,7 @@ func TestAccessLoggerCombined(t *testing.T) { config := DefaultConfig() config.Format = FormatCombined ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"", host, remote, ts, method, uri, proto, status, contentLength, referer, ua, ), @@ -91,37 +88,79 @@ func TestAccessLoggerRedactQuery(t *testing.T) { config.Format = FormatCommon config.Fields.Query.Default = FieldModeRedact ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d", host, remote, ts, method, uriRedacted, proto, status, contentLength, ), ) } +type JSONLogEntry struct { + Time string `json:"time"` + IP string `json:"ip"` + Method string `json:"method"` + Scheme string `json:"scheme"` + Host string `json:"host"` + Path string `json:"path"` + Protocol string `json:"protocol"` + Status int `json:"status"` + Error string `json:"error,omitempty"` + ContentType string `json:"type"` + Size int64 `json:"size"` + Referer string `json:"referer"` + UserAgent string `json:"useragent"` + Query map[string][]string `json:"query,omitempty"` + Headers map[string][]string `json:"headers,omitempty"` + Cookies map[string]string `json:"cookies,omitempty"` +} + func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { t.Helper() config.Format = FormatJSON var entry JSONLogEntry _, log := fmtLog(config) err := json.Unmarshal([]byte(log), &entry) - ExpectNoError(t, err) + expect.NoError(t, err) return entry } func TestAccessLoggerJSON(t *testing.T) { config := DefaultConfig() entry := getJSONEntry(t, config) - ExpectEqual(t, entry.IP, remote) - ExpectEqual(t, entry.Method, method) - ExpectEqual(t, entry.Scheme, "http") - ExpectEqual(t, entry.Host, testURL.Host) - ExpectEqual(t, entry.URI, testURL.RequestURI()) - ExpectEqual(t, entry.Protocol, proto) - ExpectEqual(t, entry.Status, status) - ExpectEqual(t, entry.ContentType, "text/plain") - ExpectEqual(t, entry.Size, contentLength) - ExpectEqual(t, entry.Referer, referer) - ExpectEqual(t, entry.UserAgent, ua) - ExpectEqual(t, len(entry.Headers), 0) - ExpectEqual(t, len(entry.Cookies), 0) + expect.Equal(t, entry.IP, remote) + expect.Equal(t, entry.Method, method) + expect.Equal(t, entry.Scheme, "http") + expect.Equal(t, entry.Host, testURL.Host) + expect.Equal(t, entry.Path, testURL.Path) + expect.Equal(t, entry.Protocol, proto) + expect.Equal(t, entry.Status, status) + expect.Equal(t, entry.ContentType, "text/plain") + expect.Equal(t, entry.Size, contentLength) + expect.Equal(t, entry.Referer, referer) + expect.Equal(t, entry.UserAgent, ua) + expect.Equal(t, len(entry.Headers), 0) + expect.Equal(t, len(entry.Cookies), 0) + if status >= 400 { + expect.Equal(t, entry.Error, http.StatusText(status)) + } +} + +func BenchmarkAccessLoggerJSON(b *testing.B) { + config := DefaultConfig() + config.Format = FormatJSON + logger := NewMockAccessLogger(testTask, config) + b.ResetTimer() + for b.Loop() { + logger.Log(req, resp) + } +} + +func BenchmarkAccessLoggerCombined(b *testing.B) { + config := DefaultConfig() + config.Format = FormatCombined + logger := NewMockAccessLogger(testTask, config) + b.ResetTimer() + for b.Loop() { + logger.Log(req, resp) + } } diff --git a/internal/net/gphttp/accesslog/back_scanner.go b/internal/net/gphttp/accesslog/back_scanner.go index 2e550059..bf17a1c7 100644 --- a/internal/net/gphttp/accesslog/back_scanner.go +++ b/internal/net/gphttp/accesslog/back_scanner.go @@ -2,32 +2,40 @@ package accesslog import ( "bytes" + "errors" "io" ) // BackScanner provides an interface to read a file backward line by line. type BackScanner struct { - file AccessLogIO - chunkSize int - offset int64 - buffer []byte - line []byte - err error + file supportRotate size int64 + chunkSize int + chunkBuf []byte + + offset int64 + chunk []byte + line []byte + err error } // NewBackScanner creates a new Scanner to read the file backward. // chunkSize determines the size of each read chunk from the end of the file. -func NewBackScanner(file AccessLogIO, chunkSize int) *BackScanner { +func NewBackScanner(file supportRotate, chunkSize int) *BackScanner { size, err := file.Seek(0, io.SeekEnd) if err != nil { return &BackScanner{err: err} } + return newBackScanner(file, size, make([]byte, chunkSize)) +} + +func newBackScanner(file supportRotate, fileSize int64, buf []byte) *BackScanner { return &BackScanner{ file: file, - chunkSize: chunkSize, - offset: size, - size: size, + size: fileSize, + offset: fileSize, + chunkSize: len(buf), + chunkBuf: buf, } } @@ -41,9 +49,9 @@ func (s *BackScanner) Scan() bool { // Read chunks until a newline is found or the file is fully read for { // Check if there's a line in the buffer - if idx := bytes.LastIndexByte(s.buffer, '\n'); idx >= 0 { - s.line = s.buffer[idx+1:] - s.buffer = s.buffer[:idx] + if idx := bytes.LastIndexByte(s.chunk, '\n'); idx >= 0 { + s.line = s.chunk[idx+1:] + s.chunk = s.chunk[:idx] if len(s.line) > 0 { return true } @@ -53,9 +61,9 @@ func (s *BackScanner) Scan() bool { for { if s.offset <= 0 { // No more data to read; check remaining buffer - if len(s.buffer) > 0 { - s.line = s.buffer - s.buffer = nil + if len(s.chunk) > 0 { + s.line = s.chunk + s.chunk = nil return true } return false @@ -63,22 +71,27 @@ func (s *BackScanner) Scan() bool { newOffset := max(0, s.offset-int64(s.chunkSize)) chunkSize := s.offset - newOffset - chunk := make([]byte, chunkSize) + chunk := s.chunkBuf[:chunkSize] n, err := s.file.ReadAt(chunk, newOffset) - if err != nil && err != io.EOF { - s.err = err + if err != nil { + if !errors.Is(err, io.EOF) { + s.err = err + } + return false + } else if n == 0 { return false } // Prepend the chunk to the buffer - s.buffer = append(chunk[:n], s.buffer...) + clone := append([]byte{}, chunk[:n]...) + s.chunk = append(clone, s.chunk...) s.offset = newOffset // Check for newline in the updated buffer - if idx := bytes.LastIndexByte(s.buffer, '\n'); idx >= 0 { - s.line = s.buffer[idx+1:] - s.buffer = s.buffer[:idx] + if idx := bytes.LastIndexByte(s.chunk, '\n'); idx >= 0 { + s.line = s.chunk[idx+1:] + s.chunk = s.chunk[:idx] if len(s.line) > 0 { return true } @@ -102,3 +115,12 @@ func (s *BackScanner) FileSize() int64 { func (s *BackScanner) Err() error { return s.err } + +func (s *BackScanner) Reset() error { + _, err := s.file.Seek(0, io.SeekStart) + if err != nil { + return err + } + *s = *newBackScanner(s.file, s.size, s.chunkBuf) + return nil +} diff --git a/internal/net/gphttp/accesslog/back_scanner_test.go b/internal/net/gphttp/accesslog/back_scanner_test.go index 939b4118..59e4ca67 100644 --- a/internal/net/gphttp/accesslog/back_scanner_test.go +++ b/internal/net/gphttp/accesslog/back_scanner_test.go @@ -2,8 +2,16 @@ package accesslog import ( "fmt" + "net/http" + "net/http/httptest" + "os" "strings" "testing" + + "github.com/spf13/afero" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils/strutils" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestBackScanner(t *testing.T) { @@ -52,7 +60,7 @@ func TestBackScanner(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Setup mock file - mockFile := &MockFile{} + mockFile := NewMockFile() _, err := mockFile.Write([]byte(tt.input)) if err != nil { t.Fatalf("failed to write to mock file: %v", err) @@ -94,7 +102,7 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) { for _, chunkSize := range chunkSizes { t.Run(fmt.Sprintf("chunk_size_%d", chunkSize), func(t *testing.T) { - mockFile := &MockFile{} + mockFile := NewMockFile() _, err := mockFile.Write([]byte(input)) if err != nil { t.Fatalf("failed to write to mock file: %v", err) @@ -125,3 +133,136 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) { }) } } + +func logEntry() []byte { + accesslog := NewMockAccessLogger(task.RootTask("test", false), &Config{ + Format: FormatJSON, + }) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + srv.URL = "http://localhost:8080" + defer srv.Close() + // make a request to the server + req, _ := http.NewRequest("GET", srv.URL, nil) + res := httptest.NewRecorder() + // server the request + srv.Config.Handler.ServeHTTP(res, req) + b := accesslog.AppendLog(nil, req, res.Result()) + if b[len(b)-1] != '\n' { + b = append(b, '\n') + } + return b +} + +func TestReset(t *testing.T) { + file, err := afero.TempFile(afero.NewOsFs(), "", "accesslog") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(file.Name()) + line := logEntry() + nLines := 1000 + for range nLines { + _, err := file.Write(line) + if err != nil { + t.Fatalf("failed to write to temp file: %v", err) + } + } + linesRead := 0 + s := NewBackScanner(file, defaultChunkSize) + for s.Scan() { + linesRead++ + } + if err := s.Err(); err != nil { + t.Errorf("scanner error: %v", err) + } + expect.Equal(t, linesRead, nLines) + s.Reset() + + linesRead = 0 + for s.Scan() { + linesRead++ + } + if err := s.Err(); err != nil { + t.Errorf("scanner error: %v", err) + } + expect.Equal(t, linesRead, nLines) +} + +// 100000 log entries +func BenchmarkBackScanner(b *testing.B) { + mockFile := NewMockFile() + line := logEntry() + for range 100000 { + _, _ = mockFile.Write(line) + } + for i := range 14 { + chunkSize := (2 << i) * kilobyte + scanner := NewBackScanner(mockFile, chunkSize) + name := strutils.FormatByteSize(chunkSize) + b.ResetTimer() + b.Run(name, func(b *testing.B) { + for b.Loop() { + _ = scanner.Reset() + for scanner.Scan() { + } + } + }) + } +} + +func BenchmarkBackScannerRealFile(b *testing.B) { + file, err := afero.TempFile(afero.NewOsFs(), "", "accesslog") + if err != nil { + b.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(file.Name()) + + for range 10000 { + _, err = file.Write(logEntry()) + if err != nil { + b.Fatalf("failed to write to temp file: %v", err) + } + } + + scanner := NewBackScanner(file, 256*kilobyte) + b.ResetTimer() + for scanner.Scan() { + } + if err := scanner.Err(); err != nil { + b.Errorf("scanner error: %v", err) + } +} + +/* +BenchmarkBackScanner +BenchmarkBackScanner/2_KiB +BenchmarkBackScanner/2_KiB-20 52 23254071 ns/op 67596663 B/op 26420 allocs/op +BenchmarkBackScanner/4_KiB +BenchmarkBackScanner/4_KiB-20 55 20961059 ns/op 62529378 B/op 13211 allocs/op +BenchmarkBackScanner/8_KiB +BenchmarkBackScanner/8_KiB-20 64 18242460 ns/op 62951141 B/op 6608 allocs/op +BenchmarkBackScanner/16_KiB +BenchmarkBackScanner/16_KiB-20 52 20162076 ns/op 62940256 B/op 3306 allocs/op +BenchmarkBackScanner/32_KiB +BenchmarkBackScanner/32_KiB-20 54 19247968 ns/op 67553645 B/op 1656 allocs/op +BenchmarkBackScanner/64_KiB +BenchmarkBackScanner/64_KiB-20 60 20909046 ns/op 64053342 B/op 827 allocs/op +BenchmarkBackScanner/128_KiB +BenchmarkBackScanner/128_KiB-20 68 17759890 ns/op 62201945 B/op 414 allocs/op +BenchmarkBackScanner/256_KiB +BenchmarkBackScanner/256_KiB-20 52 19531877 ns/op 61030487 B/op 208 allocs/op +BenchmarkBackScanner/512_KiB +BenchmarkBackScanner/512_KiB-20 54 19124656 ns/op 61030485 B/op 208 allocs/op +BenchmarkBackScanner/1_MiB +BenchmarkBackScanner/1_MiB-20 67 17078936 ns/op 61030495 B/op 208 allocs/op +BenchmarkBackScanner/2_MiB +BenchmarkBackScanner/2_MiB-20 66 18467421 ns/op 61030492 B/op 208 allocs/op +BenchmarkBackScanner/4_MiB +BenchmarkBackScanner/4_MiB-20 68 17214573 ns/op 61030486 B/op 208 allocs/op +BenchmarkBackScanner/8_MiB +BenchmarkBackScanner/8_MiB-20 57 18235229 ns/op 61030492 B/op 208 allocs/op +BenchmarkBackScanner/16_MiB +BenchmarkBackScanner/16_MiB-20 57 19343441 ns/op 61030499 B/op 208 allocs/op +*/ diff --git a/internal/net/gphttp/accesslog/config.go b/internal/net/gphttp/accesslog/config.go index a1dbe2f1..91455c0d 100644 --- a/internal/net/gphttp/accesslog/config.go +++ b/internal/net/gphttp/accesslog/config.go @@ -1,6 +1,9 @@ package accesslog -import "github.com/yusing/go-proxy/internal/utils" +import ( + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/utils" +) type ( Format string @@ -19,7 +22,8 @@ type ( Config struct { BufferSize int `json:"buffer_size"` Format Format `json:"format" validate:"oneof=common combined json"` - Path string `json:"path" validate:"required"` + Path string `json:"path"` + Stdout bool `json:"stdout"` Filters Filters `json:"filters"` Fields Fields `json:"fields"` Retention *Retention `json:"retention"` @@ -30,14 +34,24 @@ var ( FormatCommon Format = "common" FormatCombined Format = "combined" FormatJSON Format = "json" + + AvailableFormats = []Format{FormatCommon, FormatCombined, FormatJSON} ) -const DefaultBufferSize = 64 * 1024 // 64KB +const DefaultBufferSize = 64 * kilobyte // 64KB + +func (cfg *Config) Validate() gperr.Error { + if cfg.Path == "" && !cfg.Stdout { + return gperr.New("path or stdout is required") + } + return nil +} func DefaultConfig() *Config { return &Config{ BufferSize: DefaultBufferSize, Format: FormatCombined, + Retention: &Retention{Days: 30}, Fields: Fields{ Headers: FieldConfig{ Default: FieldModeDrop, diff --git a/internal/net/gphttp/accesslog/config_test.go b/internal/net/gphttp/accesslog/config_test.go index 8483638e..1a5597bd 100644 --- a/internal/net/gphttp/accesslog/config_test.go +++ b/internal/net/gphttp/accesslog/config_test.go @@ -6,7 +6,7 @@ import ( "github.com/yusing/go-proxy/internal/docker" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestNewConfig(t *testing.T) { @@ -27,27 +27,27 @@ func TestNewConfig(t *testing.T) { "proxy.fields.cookies.config.foo": "keep", } parsed, err := docker.ParseLabels(labels) - ExpectNoError(t, err) + expect.NoError(t, err) var config Config err = utils.Deserialize(parsed, &config) - ExpectNoError(t, err) + expect.NoError(t, err) - ExpectEqual(t, config.BufferSize, 10) - ExpectEqual(t, config.Format, FormatCombined) - ExpectEqual(t, config.Path, "/tmp/access.log") - ExpectEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}}) - ExpectEqual(t, len(config.Filters.Method.Values), 2) - ExpectEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"}) - ExpectEqual(t, len(config.Filters.Headers.Values), 2) - ExpectEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}}) - ExpectTrue(t, config.Filters.Headers.Negative) - ExpectEqual(t, len(config.Filters.CIDR.Values), 1) - ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24") - ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep) - ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact) - ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop) - ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep) - ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact) - ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep) + expect.Equal(t, config.BufferSize, 10) + expect.Equal(t, config.Format, FormatCombined) + expect.Equal(t, config.Path, "/tmp/access.log") + expect.Equal(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}}) + expect.Equal(t, len(config.Filters.Method.Values), 2) + expect.Equal(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"}) + expect.Equal(t, len(config.Filters.Headers.Values), 2) + expect.Equal(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}}) + expect.True(t, config.Filters.Headers.Negative) + expect.Equal(t, len(config.Filters.CIDR.Values), 1) + expect.Equal(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24") + expect.Equal(t, config.Fields.Headers.Default, FieldModeKeep) + expect.Equal(t, config.Fields.Headers.Config["foo"], FieldModeRedact) + expect.Equal(t, config.Fields.Query.Default, FieldModeDrop) + expect.Equal(t, config.Fields.Query.Config["foo"], FieldModeKeep) + expect.Equal(t, config.Fields.Cookies.Default, FieldModeRedact) + expect.Equal(t, config.Fields.Cookies.Config["foo"], FieldModeKeep) } diff --git a/internal/net/gphttp/accesslog/fields.go b/internal/net/gphttp/accesslog/fields.go index f785b94d..7b5e013c 100644 --- a/internal/net/gphttp/accesslog/fields.go +++ b/internal/net/gphttp/accesslog/fields.go @@ -1,8 +1,11 @@ package accesslog import ( + "iter" "net/http" "net/url" + + "github.com/rs/zerolog" ) type ( @@ -21,83 +24,181 @@ const ( RedactedValue = "REDACTED" ) -func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V { +type mapStringStringIter interface { + Iter(yield func(k string, v []string) bool) + MarshalZerologObject(e *zerolog.Event) +} + +type mapStringStringSlice struct { + m map[string][]string +} + +func (m mapStringStringSlice) Iter(yield func(k string, v []string) bool) { + for k, v := range m.m { + if !yield(k, v) { + return + } + } +} + +func (m mapStringStringSlice) MarshalZerologObject(e *zerolog.Event) { + for k, v := range m.m { + e.Strs(k, v) + } +} + +type mapStringStringRedacted struct { + m map[string][]string +} + +func (m mapStringStringRedacted) Iter(yield func(k string, v []string) bool) { + for k := range m.m { + if !yield(k, []string{RedactedValue}) { + return + } + } +} + +func (m mapStringStringRedacted) MarshalZerologObject(e *zerolog.Event) { + for k, v := range m.Iter { + e.Strs(k, v) + } +} + +type mapStringStringSliceWithConfig struct { + m map[string][]string + cfg *FieldConfig +} + +func (m mapStringStringSliceWithConfig) Iter(yield func(k string, v []string) bool) { + var mode FieldMode + var ok bool + for k, v := range m.m { + if mode, ok = m.cfg.Config[k]; !ok { + mode = m.cfg.Default + } + switch mode { + case FieldModeKeep: + if !yield(k, v) { + return + } + case FieldModeRedact: + if !yield(k, []string{RedactedValue}) { + return + } + } + } +} + +func (m mapStringStringSliceWithConfig) MarshalZerologObject(e *zerolog.Event) { + for k, v := range m.Iter { + e.Strs(k, v) + } +} + +type mapStringStringDrop struct{} + +func (m mapStringStringDrop) Iter(yield func(k string, v []string) bool) {} +func (m mapStringStringDrop) MarshalZerologObject(e *zerolog.Event) {} + +var mapStringStringDropIter mapStringStringIter = mapStringStringDrop{} + +func mapIter[Map http.Header | url.Values](cfg *FieldConfig, m Map) mapStringStringIter { if len(cfg.Config) == 0 { switch cfg.Default { case FieldModeKeep: - return m + return mapStringStringSlice{m: m} case FieldModeDrop: - return nil + return mapStringStringDropIter case FieldModeRedact: - redacted := make(map[string]V) - for k := range m { - redacted[k] = redactedV - } - return redacted + return mapStringStringRedacted{m: m} } } + return mapStringStringSliceWithConfig{m: m, cfg: cfg} +} - if len(m) == 0 { - return m - } +type slice[V any] struct { + s []V + getKey func(V) string + getVal func(V) string + cfg *FieldConfig +} - newMap := make(map[string]V, len(m)) - for k := range m { +type sliceIter interface { + Iter(yield func(k string, v string) bool) + MarshalZerologObject(e *zerolog.Event) +} + +func (s *slice[V]) Iter(yield func(k string, v string) bool) { + for _, v := range s.s { + k := s.getKey(v) var mode FieldMode var ok bool - if mode, ok = cfg.Config[k]; !ok { - mode = cfg.Default + if mode, ok = s.cfg.Config[k]; !ok { + mode = s.cfg.Default } switch mode { case FieldModeKeep: - newMap[k] = m[k] + if !yield(k, s.getVal(v)) { + return + } case FieldModeRedact: - newMap[k] = redactedV + if !yield(k, RedactedValue) { + return + } } } - return newMap } -func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn { +type sliceDrop struct{} + +func (s sliceDrop) Iter(yield func(k string, v string) bool) {} +func (s sliceDrop) MarshalZerologObject(e *zerolog.Event) {} + +var sliceDropIter sliceIter = sliceDrop{} + +func (s *slice[V]) MarshalZerologObject(e *zerolog.Event) { + for k, v := range s.Iter { + e.Str(k, v) + } +} + +func iterSlice[V any](cfg *FieldConfig, s []V, getKey func(V) string, getVal func(V) string) sliceIter { if len(s) == 0 || len(cfg.Config) == 0 && cfg.Default == FieldModeDrop { - return nil + return sliceDropIter } - newMap := make(map[string]VReturn, len(s)) - for _, v := range s { - var mode FieldMode - var ok bool - k := getKey(v) - if mode, ok = cfg.Config[k]; !ok { - mode = cfg.Default - } - switch mode { - case FieldModeKeep: - newMap[k] = convert(v) - case FieldModeRedact: - newMap[k] = redact(v) - } - } - return newMap + return &slice[V]{s: s, getKey: getKey, getVal: getVal, cfg: cfg} } -func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header { - return processMap(cfg, headers, []string{RedactedValue}) +func (cfg *FieldConfig) IterHeaders(headers http.Header) iter.Seq2[string, []string] { + return mapIter(cfg, headers).Iter } -func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values { - return processMap(cfg, q, []string{RedactedValue}) +func (cfg *FieldConfig) ZerologHeaders(headers http.Header) zerolog.LogObjectMarshaler { + return mapIter(cfg, headers) } -func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string { - return processSlice(cfg, cookies, - func(c *http.Cookie) string { - return c.Name - }, - func(c *http.Cookie) string { - return c.Value - }, - func(c *http.Cookie) string { - return RedactedValue - }) +func (cfg *FieldConfig) IterQuery(q url.Values) iter.Seq2[string, []string] { + return mapIter(cfg, q).Iter +} + +func (cfg *FieldConfig) ZerologQuery(q url.Values) zerolog.LogObjectMarshaler { + return mapIter(cfg, q) +} + +func cookieGetKey(c *http.Cookie) string { + return c.Name +} + +func cookieGetValue(c *http.Cookie) string { + return c.Value +} + +func (cfg *FieldConfig) IterCookies(cookies []*http.Cookie) iter.Seq2[string, string] { + return iterSlice(cfg, cookies, cookieGetKey, cookieGetValue).Iter +} + +func (cfg *FieldConfig) ZerologCookies(cookies []*http.Cookie) zerolog.LogObjectMarshaler { + return iterSlice(cfg, cookies, cookieGetKey, cookieGetValue) } diff --git a/internal/net/gphttp/accesslog/fields_test.go b/internal/net/gphttp/accesslog/fields_test.go index f4827348..8e0fd8e1 100644 --- a/internal/net/gphttp/accesslog/fields_test.go +++ b/internal/net/gphttp/accesslog/fields_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) // Cookie header should be removed, @@ -15,7 +15,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) { entry := getJSONEntry(t, config) for k, v := range req.Header { if k != "Cookie" { - ExpectEqual(t, entry.Headers[k], v) + expect.Equal(t, entry.Headers[k], v) } } @@ -24,8 +24,8 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) { "User-Agent": FieldModeDrop, } entry = getJSONEntry(t, config) - ExpectEqual(t, entry.Headers["Referer"], []string{RedactedValue}) - ExpectEqual(t, entry.Headers["User-Agent"], nil) + expect.Equal(t, entry.Headers["Referer"], []string{RedactedValue}) + expect.Equal(t, entry.Headers["User-Agent"], nil) } func TestAccessLoggerJSONDropHeaders(t *testing.T) { @@ -33,7 +33,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) { config.Fields.Headers.Default = FieldModeDrop entry := getJSONEntry(t, config) for k := range req.Header { - ExpectEqual(t, entry.Headers[k], nil) + expect.Equal(t, entry.Headers[k], nil) } config.Fields.Headers.Config = map[string]FieldMode{ @@ -41,18 +41,17 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) { "User-Agent": FieldModeRedact, } entry = getJSONEntry(t, config) - ExpectEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")}) - ExpectEqual(t, entry.Headers["User-Agent"], []string{RedactedValue}) + expect.Equal(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")}) + expect.Equal(t, entry.Headers["User-Agent"], []string{RedactedValue}) } func TestAccessLoggerJSONRedactHeaders(t *testing.T) { config := DefaultConfig() config.Fields.Headers.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) for k := range req.Header { if k != "Cookie" { - ExpectEqual(t, entry.Headers[k], []string{RedactedValue}) + expect.Equal(t, entry.Headers[k], []string{RedactedValue}) } } } @@ -62,9 +61,8 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) { config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeKeep entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) for _, cookie := range req.Cookies() { - ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value) + expect.Equal(t, entry.Cookies[cookie.Name], cookie.Value) } } @@ -73,9 +71,8 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) { config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) for _, cookie := range req.Cookies() { - ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue) + expect.Equal(t, entry.Cookies[cookie.Name], RedactedValue) } } @@ -83,14 +80,14 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) { config := DefaultConfig() config.Fields.Query.Default = FieldModeDrop entry := getJSONEntry(t, config) - ExpectEqual(t, entry.Query["foo"], nil) - ExpectEqual(t, entry.Query["bar"], nil) + expect.Equal(t, entry.Query["foo"], nil) + expect.Equal(t, entry.Query["bar"], nil) } func TestAccessLoggerJSONRedactQuery(t *testing.T) { config := DefaultConfig() config.Fields.Query.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, entry.Query["foo"], []string{RedactedValue}) - ExpectEqual(t, entry.Query["bar"], []string{RedactedValue}) + expect.Equal(t, entry.Query["foo"], []string{RedactedValue}) + expect.Equal(t, entry.Query["bar"], []string{RedactedValue}) } diff --git a/internal/net/gphttp/accesslog/file_logger.go b/internal/net/gphttp/accesslog/file_logger.go index 1b3ace15..a3679ac5 100644 --- a/internal/net/gphttp/accesslog/file_logger.go +++ b/internal/net/gphttp/accesslog/file_logger.go @@ -3,11 +3,10 @@ package accesslog import ( "fmt" "os" - "path" + pathPkg "path" "sync" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) @@ -27,16 +26,16 @@ var ( openedFilesMu sync.Mutex ) -func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { +func newFileIO(path string) (AccessLogIO, error) { openedFilesMu.Lock() var file *File - path := path.Clean(cfg.Path) + path = pathPkg.Clean(path) if opened, ok := openedFiles[path]; ok { opened.refCount.Add() file = opened } else { - f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) if err != nil { openedFilesMu.Unlock() return nil, fmt.Errorf("access log open error: %w", err) @@ -47,7 +46,7 @@ func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) } openedFilesMu.Unlock() - return NewAccessLogger(parent, file, cfg), nil + return file, nil } func (f *File) Close() error { diff --git a/internal/net/gphttp/accesslog/file_logger_test.go b/internal/net/gphttp/accesslog/file_logger_test.go index 0321a853..b9961f63 100644 --- a/internal/net/gphttp/accesslog/file_logger_test.go +++ b/internal/net/gphttp/accesslog/file_logger_test.go @@ -6,7 +6,7 @@ import ( "sync" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" "github.com/yusing/go-proxy/internal/task" ) @@ -16,26 +16,25 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { cfg := DefaultConfig() cfg.Path = "test.log" - parent := task.RootTask("test", false) loggerCount := 10 accessLogIOs := make([]AccessLogIO, loggerCount) // make test log file file, err := os.Create(cfg.Path) - ExpectNoError(t, err) + expect.NoError(t, err) file.Close() t.Cleanup(func() { - ExpectNoError(t, os.Remove(cfg.Path)) + expect.NoError(t, os.Remove(cfg.Path)) }) for i := range loggerCount { wg.Add(1) go func(index int) { defer wg.Done() - logger, err := NewFileAccessLogger(parent, cfg) - ExpectNoError(t, err) - accessLogIOs[index] = logger.io + file, err := newFileIO(cfg.Path) + expect.NoError(t, err) + accessLogIOs[index] = file }(i) } @@ -43,12 +42,12 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { firstIO := accessLogIOs[0] for _, io := range accessLogIOs { - ExpectEqual(t, io, firstIO) + expect.Equal(t, io, firstIO) } } func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { - var file MockFile + file := NewMockFile() cfg := DefaultConfig() cfg.BufferSize = 1024 @@ -59,15 +58,15 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { loggers := make([]*AccessLogger, loggerCount) for i := range loggerCount { - loggers[i] = NewAccessLogger(parent, &file, cfg) + loggers[i] = NewAccessLoggerWithIO(parent, file, cfg) } var wg sync.WaitGroup req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) resp := &http.Response{StatusCode: http.StatusOK} + wg.Add(len(loggers)) for _, logger := range loggers { - wg.Add(1) go func(l *AccessLogger) { defer wg.Done() parallelLog(l, req, resp, logCountPerLogger) @@ -78,8 +77,8 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { wg.Wait() expected := loggerCount * logCountPerLogger - actual := file.LineCount() - ExpectEqual(t, actual, expected) + actual := file.NumLines() + expect.Equal(t, actual, expected) } func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) { diff --git a/internal/net/gphttp/accesslog/filter.go b/internal/net/gphttp/accesslog/filter.go index c0c3e29d..605ec543 100644 --- a/internal/net/gphttp/accesslog/filter.go +++ b/internal/net/gphttp/accesslog/filter.go @@ -6,7 +6,7 @@ import ( "strings" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/types" + gpnet "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -24,7 +24,9 @@ type ( Key, Value string } Host string - CIDR struct{ types.CIDR } + CIDR struct { + gpnet.CIDR + } ) var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter") @@ -86,7 +88,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool { return req.Host == string(h) } -func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool { +func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool { ip, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { ip = req.RemoteAddr diff --git a/internal/net/gphttp/accesslog/filter_test.go b/internal/net/gphttp/accesslog/filter_test.go index a934a7b6..22e83604 100644 --- a/internal/net/gphttp/accesslog/filter_test.go +++ b/internal/net/gphttp/accesslog/filter_test.go @@ -1,12 +1,14 @@ package accesslog_test import ( + "net" "net/http" "testing" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" + gpnet "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestStatusCodeFilter(t *testing.T) { @@ -15,20 +17,20 @@ func TestStatusCodeFilter(t *testing.T) { } t.Run("positive", func(t *testing.T) { filter := &LogFilter[*StatusCodeRange]{} - ExpectTrue(t, filter.CheckKeep(nil, nil)) + expect.True(t, filter.CheckKeep(nil, nil)) // keep any 2xx 3xx (inclusive) filter.Values = values - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusForbidden, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusOK, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusMultipleChoices, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusPermanentRedirect, })) }) @@ -37,20 +39,20 @@ func TestStatusCodeFilter(t *testing.T) { filter := &LogFilter[*StatusCodeRange]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(nil, nil)) + expect.False(t, filter.CheckKeep(nil, nil)) // drop any 2xx 3xx (inclusive) filter.Values = values - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusForbidden, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusOK, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusMultipleChoices, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusPermanentRedirect, })) }) @@ -59,19 +61,19 @@ func TestStatusCodeFilter(t *testing.T) { func TestMethodFilter(t *testing.T) { t.Run("positive", func(t *testing.T) { filter := &LogFilter[HTTPMethod]{} - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) // keep get only filter.Values = []HTTPMethod{http.MethodGet} - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) }) @@ -80,19 +82,19 @@ func TestMethodFilter(t *testing.T) { filter := &LogFilter[HTTPMethod]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) // drop post only filter.Values = []HTTPMethod{http.MethodPost} - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) }) @@ -112,53 +114,54 @@ func TestHeaderFilter(t *testing.T) { headerFoo := []*HTTPHeader{ strutils.MustParse[*HTTPHeader]("Foo"), } - ExpectEqual(t, headerFoo[0].Key, "Foo") - ExpectEqual(t, headerFoo[0].Value, "") + expect.Equal(t, headerFoo[0].Key, "Foo") + expect.Equal(t, headerFoo[0].Value, "") headerFooBar := []*HTTPHeader{ strutils.MustParse[*HTTPHeader]("Foo=bar"), } - ExpectEqual(t, headerFooBar[0].Key, "Foo") - ExpectEqual(t, headerFooBar[0].Value, "bar") + expect.Equal(t, headerFooBar[0].Key, "Foo") + expect.Equal(t, headerFooBar[0].Value, "bar") t.Run("positive", func(t *testing.T) { filter := &LogFilter[*HTTPHeader]{} - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) // keep any foo filter.Values = headerFoo - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) // keep foo == bar filter.Values = headerFooBar - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) }) t.Run("negative", func(t *testing.T) { filter := &LogFilter[*HTTPHeader]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) // drop any foo filter.Values = headerFoo - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) // drop foo == bar filter.Values = headerFooBar - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) }) } func TestCIDRFilter(t *testing.T) { - cidr := []*CIDR{ - strutils.MustParse[*CIDR]("192.168.10.0/24"), - } - ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") + cidr := []*CIDR{{gpnet.CIDR{ + IP: net.ParseIP("192.168.10.0"), + Mask: net.CIDRMask(24, 32), + }}} + expect.Equal(t, cidr[0].String(), "192.168.10.0/24") inCIDR := &http.Request{ RemoteAddr: "192.168.10.1", } @@ -168,21 +171,21 @@ func TestCIDRFilter(t *testing.T) { t.Run("positive", func(t *testing.T) { filter := &LogFilter[*CIDR]{} - ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) - ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + expect.True(t, filter.CheckKeep(inCIDR, nil)) + expect.True(t, filter.CheckKeep(notInCIDR, nil)) filter.Values = cidr - ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) - ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + expect.True(t, filter.CheckKeep(inCIDR, nil)) + expect.False(t, filter.CheckKeep(notInCIDR, nil)) }) t.Run("negative", func(t *testing.T) { filter := &LogFilter[*CIDR]{Negative: true} - ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) - ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + expect.False(t, filter.CheckKeep(inCIDR, nil)) + expect.False(t, filter.CheckKeep(notInCIDR, nil)) filter.Values = cidr - ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) - ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + expect.False(t, filter.CheckKeep(inCIDR, nil)) + expect.True(t, filter.CheckKeep(notInCIDR, nil)) }) } diff --git a/internal/net/gphttp/accesslog/formatter.go b/internal/net/gphttp/accesslog/formatter.go index d0fcd681..33e5c5bf 100644 --- a/internal/net/gphttp/accesslog/formatter.go +++ b/internal/net/gphttp/accesslog/formatter.go @@ -2,42 +2,20 @@ package accesslog import ( "bytes" - "encoding/json" + "iter" "net" "net/http" - "net/url" "strconv" - "time" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog" ) type ( CommonFormatter struct { - cfg *Fields - GetTimeNow func() time.Time // for testing purposes only + cfg *Fields } CombinedFormatter struct{ CommonFormatter } JSONFormatter struct{ CommonFormatter } - - JSONLogEntry struct { - Time string `json:"time"` - IP string `json:"ip"` - Method string `json:"method"` - Scheme string `json:"scheme"` - Host string `json:"host"` - URI string `json:"uri"` - Protocol string `json:"protocol"` - Status int `json:"status"` - Error string `json:"error,omitempty"` - ContentType string `json:"type"` - Size int64 `json:"size"` - Referer string `json:"referer"` - UserAgent string `json:"useragent"` - Query map[string][]string `json:"query,omitempty"` - Headers map[string][]string `json:"headers,omitempty"` - Cookies map[string]string `json:"cookies,omitempty"` - } ) const LogTimeFormat = "02/Jan/2006:15:04:05 -0700" @@ -49,12 +27,24 @@ func scheme(req *http.Request) string { return "http" } -func requestURI(u *url.URL, query url.Values) string { - uri := u.EscapedPath() - if len(query) > 0 { - uri += "?" + query.Encode() +func appendRequestURI(line []byte, req *http.Request, query iter.Seq2[string, []string]) []byte { + uri := req.URL.EscapedPath() + line = append(line, uri...) + isFirst := true + for k, v := range query { + if isFirst { + line = append(line, '?') + isFirst = false + } else { + line = append(line, '&') + } + line = append(line, k...) + line = append(line, '=') + for _, v := range v { + line = append(line, v...) + } } - return uri + return line } func clientIP(req *http.Request) string { @@ -65,80 +55,102 @@ func clientIP(req *http.Request) string { return req.RemoteAddr } -// debug only. -func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) { - f.GetTimeNow = getTimeNow +func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { + query := f.cfg.Query.IterQuery(req.URL.Query()) + + line = append(line, req.Host...) + line = append(line, ' ') + + line = append(line, clientIP(req)...) + line = append(line, " - - ["...) + + line = TimeNow().AppendFormat(line, LogTimeFormat) + line = append(line, `] "`...) + + line = append(line, req.Method...) + line = append(line, ' ') + line = appendRequestURI(line, req, query) + line = append(line, ' ') + line = append(line, req.Proto...) + line = append(line, '"') + line = append(line, ' ') + + line = strconv.AppendInt(line, int64(res.StatusCode), 10) + line = append(line, ' ') + line = strconv.AppendInt(line, res.ContentLength, 10) + return line } -func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { - query := f.cfg.Query.ProcessQuery(req.URL.Query()) - - line.WriteString(req.Host) - line.WriteRune(' ') - - line.WriteString(clientIP(req)) - line.WriteString(" - - [") - - line.WriteString(f.GetTimeNow().Format(LogTimeFormat)) - line.WriteString("] \"") - - line.WriteString(req.Method) - line.WriteRune(' ') - line.WriteString(requestURI(req.URL, query)) - line.WriteRune(' ') - line.WriteString(req.Proto) - line.WriteString("\" ") - - line.WriteString(strconv.Itoa(res.StatusCode)) - line.WriteRune(' ') - line.WriteString(strconv.FormatInt(res.ContentLength, 10)) +func (f *CombinedFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { + line = f.CommonFormatter.AppendLog(line, req, res) + line = append(line, " \""...) + line = append(line, req.Referer()...) + line = append(line, "\" \""...) + line = append(line, req.UserAgent()...) + line = append(line, '"') + return line } -func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { - f.CommonFormatter.Format(line, req, res) - line.WriteString(" \"") - line.WriteString(req.Referer()) - line.WriteString("\" \"") - line.WriteString(req.UserAgent()) - line.WriteRune('"') +type zeroLogStringStringMapMarshaler struct { + values map[string]string } -func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { - query := f.cfg.Query.ProcessQuery(req.URL.Query()) - headers := f.cfg.Headers.ProcessHeaders(req.Header) - headers.Del("Cookie") - cookies := f.cfg.Cookies.ProcessCookies(req.Cookies()) - - entry := JSONLogEntry{ - Time: f.GetTimeNow().Format(LogTimeFormat), - IP: clientIP(req), - Method: req.Method, - Scheme: scheme(req), - Host: req.Host, - URI: requestURI(req.URL, query), - Protocol: req.Proto, - Status: res.StatusCode, - ContentType: res.Header.Get("Content-Type"), - Size: res.ContentLength, - Referer: req.Referer(), - UserAgent: req.UserAgent(), - Query: query, - Headers: headers, - Cookies: cookies, +func (z *zeroLogStringStringMapMarshaler) MarshalZerologObject(e *zerolog.Event) { + if len(z.values) == 0 { + return } + for k, v := range z.values { + e.Str(k, v) + } +} + +type zeroLogStringStringSliceMapMarshaler struct { + values map[string][]string +} + +func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.Event) { + if len(z.values) == 0 { + return + } + for k, v := range z.values { + e.Strs(k, v) + } +} + +func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { + query := f.cfg.Query.ZerologQuery(req.URL.Query()) + headers := f.cfg.Headers.ZerologHeaders(req.Header) + cookies := f.cfg.Cookies.ZerologCookies(req.Cookies()) + contentType := res.Header.Get("Content-Type") + + writer := bytes.NewBuffer(line) + logger := zerolog.New(writer).With().Logger() + event := logger.Info(). + Str("time", TimeNow().Format(LogTimeFormat)). + Str("ip", clientIP(req)). + Str("method", req.Method). + Str("scheme", scheme(req)). + Str("host", req.Host). + Str("path", req.URL.Path). + Str("protocol", req.Proto). + Int("status", res.StatusCode). + Str("type", contentType). + Int64("size", res.ContentLength). + Str("referer", req.Referer()). + Str("useragent", req.UserAgent()). + Object("query", query). + Object("headers", headers). + Object("cookies", cookies) if res.StatusCode >= 400 { - entry.Error = res.Status + if res.Status != "" { + event.Str("error", res.Status) + } else { + event.Str("error", http.StatusText(res.StatusCode)) + } } - if entry.ContentType == "" { - // try to get content type from request - entry.ContentType = req.Header.Get("Content-Type") - } - - marshaller := json.NewEncoder(line) - err := marshaller.Encode(entry) - if err != nil { - logging.Err(err).Msg("failed to marshal json log") - } + // NOTE: zerolog will append a newline to the buffer + event.Send() + return writer.Bytes() } diff --git a/internal/net/gphttp/accesslog/mock_file.go b/internal/net/gphttp/accesslog/mock_file.go index 54c30c63..acb2d100 100644 --- a/internal/net/gphttp/accesslog/mock_file.go +++ b/internal/net/gphttp/accesslog/mock_file.go @@ -3,75 +3,47 @@ package accesslog import ( "bytes" "io" - "sync" + + "github.com/spf13/afero" ) +type noLock struct{} + +func (noLock) Lock() {} +func (noLock) Unlock() {} + type MockFile struct { - data []byte - position int64 - sync.Mutex + afero.File + noLock } -func (m *MockFile) Seek(offset int64, whence int) (int64, error) { - switch whence { - case io.SeekStart: - m.position = offset - case io.SeekCurrent: - m.position += offset - case io.SeekEnd: - m.position = int64(len(m.data)) + offset +func NewMockFile() *MockFile { + f, _ := afero.TempFile(afero.NewMemMapFs(), "", "") + return &MockFile{ + File: f, } - return m.position, nil -} - -func (m *MockFile) Write(p []byte) (n int, err error) { - m.data = append(m.data, p...) - n = len(p) - m.position += int64(n) - return -} - -func (m *MockFile) Name() string { - return "mock" -} - -func (m *MockFile) Read(p []byte) (n int, err error) { - if m.position >= int64(len(m.data)) { - return 0, io.EOF - } - n = copy(p, m.data[m.position:]) - m.position += int64(n) - return n, nil -} - -func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) { - if off >= int64(len(m.data)) { - return 0, io.EOF - } - n = copy(p, m.data[off:]) - return n, nil -} - -func (m *MockFile) Close() error { - return nil -} - -func (m *MockFile) Truncate(size int64) error { - m.data = m.data[:size] - m.position = size - return nil -} - -func (m *MockFile) LineCount() int { - m.Lock() - defer m.Unlock() - return bytes.Count(m.data[:m.position], []byte("\n")) } func (m *MockFile) Len() int64 { - return m.position + filesize, _ := m.Seek(0, io.SeekEnd) + _, _ = m.Seek(0, io.SeekStart) + return filesize } func (m *MockFile) Content() []byte { - return m.data[:m.position] + buf := bytes.NewBuffer(nil) + m.Seek(0, io.SeekStart) + _, _ = buf.ReadFrom(m.File) + m.Seek(0, io.SeekStart) + return buf.Bytes() +} + +func (m *MockFile) NumLines() int { + content := m.Content() + count := bytes.Count(content, []byte("\n")) + // account for last line if it does not end with a newline + if len(content) > 0 && content[len(content)-1] != '\n' { + count++ + } + return count } diff --git a/internal/net/gphttp/accesslog/multi_writer.go b/internal/net/gphttp/accesslog/multi_writer.go new file mode 100644 index 00000000..3577bc48 --- /dev/null +++ b/internal/net/gphttp/accesslog/multi_writer.go @@ -0,0 +1,46 @@ +package accesslog + +import "strings" + +type MultiWriter struct { + writers []AccessLogIO +} + +func NewMultiWriter(writers ...AccessLogIO) AccessLogIO { + if len(writers) == 0 { + return nil + } + if len(writers) == 1 { + return writers[0] + } + return &MultiWriter{ + writers: writers, + } +} + +func (w *MultiWriter) Write(p []byte) (n int, err error) { + for _, writer := range w.writers { + writer.Write(p) + } + return len(p), nil +} + +func (w *MultiWriter) Lock() { + for _, writer := range w.writers { + writer.Lock() + } +} + +func (w *MultiWriter) Unlock() { + for _, writer := range w.writers { + writer.Unlock() + } +} + +func (w *MultiWriter) Name() string { + names := make([]string, len(w.writers)) + for i, writer := range w.writers { + names[i] = writer.Name() + } + return strings.Join(names, ", ") +} diff --git a/internal/net/gphttp/accesslog/retention.go b/internal/net/gphttp/accesslog/retention.go index f0b5e2a7..3a130cc2 100644 --- a/internal/net/gphttp/accesslog/retention.go +++ b/internal/net/gphttp/accesslog/retention.go @@ -1,6 +1,7 @@ package accesslog import ( + "fmt" "strconv" "github.com/yusing/go-proxy/internal/gperr" @@ -8,8 +9,9 @@ import ( ) type Retention struct { - Days uint64 `json:"days"` - Last uint64 `json:"last"` + Days uint64 `json:"days"` + Last uint64 `json:"last"` + KeepSize uint64 `json:"keep_size"` } var ( @@ -17,7 +19,8 @@ var ( ErrZeroValue = gperr.New("zero value") ) -var defaultChunkSize = 64 * 1024 // 64KB +// see back_scanner_test.go#L210 for benchmarks +var defaultChunkSize = 256 * kilobyte // Syntax: // @@ -25,6 +28,8 @@ var defaultChunkSize = 64 * 1024 // 64KB // // last // +// KB|MB|GB|kb|mb|gb +// // Parse implements strutils.Parser. func (r *Retention) Parse(v string) (err error) { split := strutils.SplitSpace(v) @@ -35,22 +40,55 @@ func (r *Retention) Parse(v string) (err error) { case "last": r.Last, err = strconv.ParseUint(split[1], 10, 64) default: // days|weeks|months - r.Days, err = strconv.ParseUint(split[0], 10, 64) + n, err := strconv.ParseUint(split[0], 10, 64) if err != nil { - return + return err } switch split[1] { - case "days": - case "weeks": - r.Days *= 7 - case "months": - r.Days *= 30 + case "day", "days": + r.Days = n + case "week", "weeks": + r.Days = n * 7 + case "month", "months": + r.Days = n * 30 + case "kb", "Kb": + r.KeepSize = n * kilobits + case "KB": + r.KeepSize = n * kilobyte + case "mb", "Mb": + r.KeepSize = n * megabits + case "MB": + r.KeepSize = n * megabyte + case "gb", "Gb": + r.KeepSize = n * gigabits + case "GB": + r.KeepSize = n * gigabyte default: return ErrInvalidSyntax.Subject("unit " + split[1]) } } - if r.Days == 0 && r.Last == 0 { + if !r.IsValid() { return ErrZeroValue } return } + +func (r *Retention) String() string { + if r.Days > 0 { + return fmt.Sprintf("%d days", r.Days) + } + if r.Last > 0 { + return fmt.Sprintf("last %d", r.Last) + } + if r.KeepSize > 0 { + return strutils.FormatByteSize(r.KeepSize) + } + return "" +} + +func (r *Retention) IsValid() bool { + if r == nil { + return false + } + return r.Days > 0 || r.Last > 0 || r.KeepSize > 0 +} diff --git a/internal/net/gphttp/accesslog/retention_test.go b/internal/net/gphttp/accesslog/retention_test.go index 2a2eb988..125039e7 100644 --- a/internal/net/gphttp/accesslog/retention_test.go +++ b/internal/net/gphttp/accesslog/retention_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseRetention(t *testing.T) { @@ -24,9 +24,9 @@ func TestParseRetention(t *testing.T) { r := &Retention{} err := r.Parse(test.input) if !test.shouldErr { - ExpectNoError(t, err) + expect.NoError(t, err) } else { - ExpectEqual(t, r, test.expected) + expect.Equal(t, r, test.expected) } }) } diff --git a/internal/net/gphttp/accesslog/rotate.go b/internal/net/gphttp/accesslog/rotate.go index e93c22d0..4578aaf4 100644 --- a/internal/net/gphttp/accesslog/rotate.go +++ b/internal/net/gphttp/accesslog/rotate.go @@ -4,116 +4,252 @@ import ( "bytes" "io" "time" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/utils/strutils" + "github.com/yusing/go-proxy/internal/utils/synk" ) -func (l *AccessLogger) rotate() (err error) { - // Get retention configuration - config := l.Config().Retention - var shouldKeep func(t time.Time, lineCount int) bool +type supportRotate interface { + io.Reader + io.Writer + io.Seeker + io.ReaderAt + io.WriterAt + Truncate(size int64) error +} + +type RotateResult struct { + Filename string + OriginalSize int64 // original size of the file + NumBytesRead int64 // number of bytes read from the file + NumBytesKeep int64 // number of bytes to keep + NumLinesRead int // number of lines read from the file + NumLinesKeep int // number of lines to keep + NumLinesInvalid int // number of invalid lines +} + +func (r *RotateResult) Print(logger *zerolog.Logger) { + logger.Info(). + Str("original_size", strutils.FormatByteSize(r.OriginalSize)). + Str("bytes_read", strutils.FormatByteSize(r.NumBytesRead)). + Str("bytes_keep", strutils.FormatByteSize(r.NumBytesKeep)). + Int("lines_read", r.NumLinesRead). + Int("lines_keep", r.NumLinesKeep). + Int("lines_invalid", r.NumLinesInvalid). + Msg("log rotate result") +} + +type lineInfo struct { + Pos int64 // Position from the start of the file + Size int64 // Size of this line +} + +// do not allocate initial size +var rotateBytePool = synk.NewBytesPool(0, 16*1024*1024) + +// rotateLogFile rotates the log file based on the retention policy. +// It returns the result of the rotation and an error if any. +// +// The file is rotated by reading the file backward line-by-line +// and stop once error occurs or found a line that should not be kept. +// +// Any invalid lines will be skipped and not included in the result. +// +// If the file does not need to be rotated, it returns nil, nil. +func rotateLogFile(file supportRotate, config *Retention) (result *RotateResult, err error) { + if config.KeepSize > 0 { + return rotateLogFileBySize(file, config) + } + + var shouldStop func() bool + t := TimeNow() if config.Last > 0 { - shouldKeep = func(_ time.Time, lineCount int) bool { - return lineCount < int(config.Last) - } + shouldStop = func() bool { return result.NumLinesKeep-result.NumLinesInvalid == int(config.Last) } + // not needed to parse time for last N lines } else if config.Days > 0 { - cutoff := time.Now().AddDate(0, 0, -int(config.Days)) - shouldKeep = func(t time.Time, _ int) bool { - return !t.IsZero() && !t.Before(cutoff) - } + cutoff := TimeNow().AddDate(0, 0, -int(config.Days)+1) + shouldStop = func() bool { return t.Before(cutoff) } } else { - return nil // No retention policy set + return nil, nil // should not happen } - s := NewBackScanner(l.io, defaultChunkSize) - nRead := 0 - nLines := 0 + s := NewBackScanner(file, defaultChunkSize) + result = &RotateResult{ + OriginalSize: s.FileSize(), + } + + // nothing to rotate, return the nothing + if result.OriginalSize == 0 { + return nil, nil + } + + // Store the line positions and sizes we want to keep + linesToKeep := make([]lineInfo, 0) + lastLineValid := false + for s.Scan() { - nRead += len(s.Bytes()) + 1 - nLines++ - t := ParseLogTime(s.Bytes()) - if !shouldKeep(t, nLines) { + result.NumLinesRead++ + lineSize := int64(len(s.Bytes()) + 1) // +1 for newline + linePos := result.OriginalSize - result.NumBytesRead - lineSize + result.NumBytesRead += lineSize + + // Check if line has valid time + t = ParseLogTime(s.Bytes()) + if t.IsZero() { + result.NumLinesInvalid++ + lastLineValid = false + continue + } + + // Check if we should stop based on retention policy + if shouldStop() { break } + + // Add line to those we want to keep + if lastLineValid { + last := linesToKeep[len(linesToKeep)-1] + linesToKeep[len(linesToKeep)-1] = lineInfo{ + Pos: last.Pos - lineSize, + Size: last.Size + lineSize, + } + } else { + linesToKeep = append(linesToKeep, lineInfo{ + Pos: linePos, + Size: lineSize, + }) + } + result.NumBytesKeep += lineSize + result.NumLinesKeep++ + lastLineValid = true } + if s.Err() != nil { - return s.Err() + return nil, s.Err() } - beg := int64(nRead) - if _, err := l.io.Seek(-beg, io.SeekEnd); err != nil { - return err - } - buf := make([]byte, nRead) - if _, err := l.io.Read(buf); err != nil { - return err + // nothing to keep, truncate to empty + if len(linesToKeep) == 0 { + return nil, file.Truncate(0) } - if err := l.writeTruncate(buf); err != nil { - return err + // nothing to rotate, return nothing + if result.NumBytesKeep == result.OriginalSize { + return nil, nil } - return nil + + // Read each line and write it to the beginning of the file + writePos := int64(0) + buf := rotateBytePool.Get() + defer rotateBytePool.Put(buf) + + // in reverse order to keep the order of the lines (from old to new) + for i := len(linesToKeep) - 1; i >= 0; i-- { + line := linesToKeep[i] + n := line.Size + if cap(buf) < int(n) { + buf = make([]byte, n) + } + buf = buf[:n] + + // Read the line from its original position + if _, err := file.ReadAt(buf, line.Pos); err != nil { + return nil, err + } + + // Write it to the new position + if _, err := file.WriteAt(buf, writePos); err != nil { + return nil, err + } + writePos += n + } + + if err := file.Truncate(writePos); err != nil { + return nil, err + } + + return result, nil } -func (l *AccessLogger) writeTruncate(buf []byte) (err error) { - // Seek to beginning and truncate - if _, err := l.io.Seek(0, 0); err != nil { - return err - } - - // Write buffer back to file - nWritten, err := l.buffered.Write(buf) +// rotateLogFileBySize rotates the log file by size. +// It returns the result of the rotation and an error if any. +// +// The file is not being read, it just truncate the file to the new size. +// +// Invalid lines will not be detected and included in the result. +func rotateLogFileBySize(file supportRotate, config *Retention) (result *RotateResult, err error) { + filesize, err := file.Seek(0, io.SeekEnd) if err != nil { - return err - } - if err = l.buffered.Flush(); err != nil { - return err + return nil, err } - // Truncate file - if err = l.io.Truncate(int64(nWritten)); err != nil { - return err + result = &RotateResult{ + OriginalSize: filesize, } - // check bytes written == buffer size - if nWritten != len(buf) { - return io.ErrShortWrite + keepSize := int64(config.KeepSize) + if keepSize >= filesize { + result.NumBytesKeep = filesize + return result, nil } - return + result.NumBytesKeep = keepSize + + err = file.Truncate(keepSize) + if err != nil { + return nil, err + } + + return result, nil } -const timeLen = len(`"time":"`) - -var timeJSON = []byte(`"time":"`) - +// ParseLogTime parses the time from the log line. +// It returns the time if the time is found and valid in the log line, +// otherwise it returns zero time. func ParseLogTime(line []byte) (t time.Time) { if len(line) == 0 { return } - if i := bytes.Index(line, timeJSON); i != -1 { // JSON format - var jsonStart = i + timeLen - var jsonEnd = i + timeLen + len(LogTimeFormat) - if len(line) < jsonEnd { - return - } - timeStr := line[jsonStart:jsonEnd] - t, _ = time.Parse(LogTimeFormat, string(timeStr)) + if timeStr := ExtractTime(line); timeStr != nil { + t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error return } - - // Common/Combined format - // Format: - - [02/Jan/2006:15:04:05 -0700] ... - start := bytes.IndexByte(line, '[') - if start == -1 { - return - } - end := bytes.IndexByte(line[start:], ']') - if end == -1 { - return - } - end += start // adjust end position relative to full line - - timeStr := line[start+1 : end] - t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error return } + +var timeJSON = []byte(`"time":"`) + +// ExtractTime extracts the time from the log line. +// It returns the time if the time is found, +// otherwise it returns nil. +// +// The returned time is not validated. +func ExtractTime(line []byte) []byte { + //TODO: optimize this + switch line[0] { + case '{': // JSON format + if i := bytes.Index(line, timeJSON); i != -1 { + var jsonStart = i + len(`"time":"`) + var jsonEnd = i + len(`"time":"`) + len(LogTimeFormat) + if len(line) < jsonEnd { + return nil + } + return line[jsonStart:jsonEnd] + } + return nil // invalid JSON line + default: + // Common/Combined format + // Format: - - [02/Jan/2006:15:04:05 -0700] ... + start := bytes.IndexByte(line, '[') + if start == -1 { + return nil + } + end := start + 1 + len(LogTimeFormat) + if len(line) < end { + return nil + } + return line[start+1 : end] + } +} diff --git a/internal/net/gphttp/accesslog/rotate_test.go b/internal/net/gphttp/accesslog/rotate_test.go index 8b81792f..5b963106 100644 --- a/internal/net/gphttp/accesslog/rotate_test.go +++ b/internal/net/gphttp/accesslog/rotate_test.go @@ -1,6 +1,7 @@ package accesslog_test import ( + "bytes" "fmt" "testing" "time" @@ -8,79 +9,280 @@ import ( . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" +) + +var ( + testTime = expect.Must(time.Parse(time.RFC3339, "2024-01-31T03:04:05Z")) + testTimeStr = testTime.Format(LogTimeFormat) ) func TestParseLogTime(t *testing.T) { - tests := []string{ - `{"foo":"bar","time":"%s","bar":"baz"}`, - `example.com 192.168.1.1 - - [%s] "GET / HTTP/1.1" 200 1234`, - } - testTime := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) - testTimeStr := testTime.Format(LogTimeFormat) + t.Run("valid time", func(t *testing.T) { + tests := []string{ + `{"foo":"bar","time":"%s","bar":"baz"}`, + `example.com 192.168.1.1 - - [%s] "GET / HTTP/1.1" 200 1234`, + } - for i, test := range tests { - tests[i] = fmt.Sprintf(test, testTimeStr) - } + for i, test := range tests { + tests[i] = fmt.Sprintf(test, testTimeStr) + } - for _, test := range tests { - t.Run(test, func(t *testing.T) { - actual := ParseLogTime([]byte(test)) - ExpectTrue(t, actual.Equal(testTime)) + for _, test := range tests { + t.Run(test, func(t *testing.T) { + extracted := ExtractTime([]byte(test)) + expect.Equal(t, string(extracted), testTimeStr) + got := ParseLogTime([]byte(test)) + expect.True(t, got.Equal(testTime), "expected %s, got %s", testTime, got) + }) + } + }) + + t.Run("invalid time", func(t *testing.T) { + tests := []string{ + `{"foo":"bar","time":"invalid","bar":"baz"}`, + `example.com 192.168.1.1 - - [invalid] "GET / HTTP/1.1" 200 1234`, + } + for _, test := range tests { + t.Run(test, func(t *testing.T) { + expect.True(t, ParseLogTime([]byte(test)).IsZero(), "expected zero time, got %s", ParseLogTime([]byte(test))) + }) + } + }) +} + +func TestRotateKeepLast(t *testing.T) { + for _, format := range AvailableFormats { + t.Run(string(format)+" keep last", func(t *testing.T) { + file := NewMockFile() + MockTimeNow(testTime) + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: format, + }) + expect.Nil(t, logger.Config().Retention) + + for range 10 { + logger.Log(req, resp) + } + expect.NoError(t, logger.Flush()) + + expect.Greater(t, file.Len(), int64(0)) + expect.Equal(t, file.NumLines(), 10) + + retention := strutils.MustParse[*Retention]("last 5") + expect.Equal(t, retention.Days, 0) + expect.Equal(t, retention.Last, 5) + expect.Equal(t, retention.KeepSize, 0) + logger.Config().Retention = retention + + result, err := logger.Rotate() + expect.NoError(t, err) + expect.Equal(t, file.NumLines(), int(retention.Last)) + expect.Equal(t, result.NumLinesKeep, int(retention.Last)) + expect.Equal(t, result.NumLinesInvalid, 0) + }) + + t.Run(string(format)+" keep days", func(t *testing.T) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: format, + }) + expect.Nil(t, logger.Config().Retention) + nLines := 10 + for i := range nLines { + MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1)) + logger.Log(req, resp) + } + logger.Flush() + expect.Equal(t, file.NumLines(), nLines) + + retention := strutils.MustParse[*Retention]("3 days") + expect.Equal(t, retention.Days, 3) + expect.Equal(t, retention.Last, 0) + expect.Equal(t, retention.KeepSize, 0) + logger.Config().Retention = retention + + MockTimeNow(testTime) + result, err := logger.Rotate() + expect.NoError(t, err) + expect.Equal(t, file.NumLines(), int(retention.Days)) + expect.Equal(t, result.NumLinesKeep, int(retention.Days)) + expect.Equal(t, result.NumLinesInvalid, 0) + + rotated := file.Content() + rotatedLines := bytes.Split(rotated, []byte("\n")) + for i, line := range rotatedLines { + if i >= int(retention.Days) { // may ends with a newline + break + } + timeBytes := ExtractTime(line) + got, err := time.Parse(LogTimeFormat, string(timeBytes)) + expect.NoError(t, err) + want := testTime.AddDate(0, 0, -int(retention.Days)+i+1) + expect.True(t, got.Equal(want), "expected %s, got %s", want, got) + } }) } } -func TestRetentionCommonFormat(t *testing.T) { - var file MockFile - logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{ - Format: FormatCommon, - BufferSize: 1024, - }) - for range 10 { - logger.Log(req, resp) - } - logger.Flush() - // test.Finish(nil) - - ExpectEqual(t, logger.Config().Retention, nil) - ExpectTrue(t, file.Len() > 0) - ExpectEqual(t, file.LineCount(), 10) - - t.Run("keep last", func(t *testing.T) { - logger.Config().Retention = strutils.MustParse[*Retention]("last 5") - ExpectEqual(t, logger.Config().Retention.Days, 0) - ExpectEqual(t, logger.Config().Retention.Last, 5) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.LineCount(), 5) - }) - - _ = file.Truncate(0) - - timeNow := time.Now() - for i := range 10 { - logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time { - return timeNow.AddDate(0, 0, -10+i) - } - logger.Log(req, resp) - } - logger.Flush() - ExpectEqual(t, file.LineCount(), 10) - - t.Run("keep days", func(t *testing.T) { - logger.Config().Retention = strutils.MustParse[*Retention]("3 days") - ExpectEqual(t, logger.Config().Retention.Days, 3) - ExpectEqual(t, logger.Config().Retention.Last, 0) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.LineCount(), 3) - rotated := string(file.Content()) - _ = file.Truncate(0) - for i := range 3 { - logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time { - return timeNow.AddDate(0, 0, -3+i) +func TestRotateKeepFileSize(t *testing.T) { + for _, format := range AvailableFormats { + t.Run(string(format)+" keep size no rotation", func(t *testing.T) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: format, + }) + expect.Nil(t, logger.Config().Retention) + nLines := 10 + for i := range nLines { + MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1)) + logger.Log(req, resp) } + logger.Flush() + expect.Equal(t, file.NumLines(), nLines) + + retention := strutils.MustParse[*Retention]("100 KB") + expect.Equal(t, retention.KeepSize, 100*1024) + expect.Equal(t, retention.Days, 0) + expect.Equal(t, retention.Last, 0) + logger.Config().Retention = retention + + MockTimeNow(testTime) + result, err := logger.Rotate() + expect.NoError(t, err) + + // file should be untouched as 100KB > 10 lines * bytes per line + expect.Equal(t, result.NumBytesKeep, file.Len()) + expect.Equal(t, result.NumBytesRead, 0, "should not read any bytes") + }) + } + + t.Run("keep size with rotation", func(t *testing.T) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: FormatJSON, + }) + expect.Nil(t, logger.Config().Retention) + nLines := 100 + for i := range nLines { + MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1)) logger.Log(req, resp) } - ExpectEqual(t, rotated, string(file.Content())) + logger.Flush() + expect.Equal(t, file.NumLines(), nLines) + + retention := strutils.MustParse[*Retention]("10 KB") + expect.Equal(t, retention.KeepSize, 10*1024) + expect.Equal(t, retention.Days, 0) + expect.Equal(t, retention.Last, 0) + logger.Config().Retention = retention + + MockTimeNow(testTime) + result, err := logger.Rotate() + expect.NoError(t, err) + expect.Equal(t, result.NumBytesKeep, int64(retention.KeepSize)) + expect.Equal(t, file.Len(), int64(retention.KeepSize)) + expect.Equal(t, result.NumBytesRead, 0, "should not read any bytes") }) } + +// skipping invalid lines is not supported for keep file_size +func TestRotateSkipInvalidTime(t *testing.T) { + for _, format := range AvailableFormats { + t.Run(string(format), func(t *testing.T) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: format, + }) + expect.Nil(t, logger.Config().Retention) + nLines := 10 + for i := range nLines { + MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1)) + logger.Log(req, resp) + logger.Flush() + + n, err := file.Write([]byte("invalid time\n")) + expect.NoError(t, err) + expect.Equal(t, n, len("invalid time\n")) + } + expect.Equal(t, file.NumLines(), 2*nLines) + + retention := strutils.MustParse[*Retention]("3 days") + expect.Equal(t, retention.Days, 3) + expect.Equal(t, retention.Last, 0) + logger.Config().Retention = retention + + result, err := logger.Rotate() + expect.NoError(t, err) + // should read one invalid line after every valid line + expect.Equal(t, result.NumLinesKeep, int(retention.Days)) + expect.Equal(t, result.NumLinesInvalid, nLines-int(retention.Days)*2) + expect.Equal(t, file.NumLines(), int(retention.Days)) + }) + } +} + +func BenchmarkRotate(b *testing.B) { + tests := []*Retention{ + {Days: 30}, + {Last: 100}, + {KeepSize: 24 * 1024}, + } + for _, retention := range tests { + b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: FormatJSON, + Retention: retention, + }) + for i := range 100 { + MockTimeNow(testTime.AddDate(0, 0, -100+i+1)) + logger.Log(req, resp) + } + logger.Flush() + content := file.Content() + b.ResetTimer() + for b.Loop() { + b.StopTimer() + file = NewMockFile() + _, _ = file.Write(content) + b.StartTimer() + _, _ = logger.Rotate() + } + }) + } +} + +func BenchmarkRotateWithInvalidTime(b *testing.B) { + tests := []*Retention{ + {Days: 30}, + {Last: 100}, + {KeepSize: 24 * 1024}, + } + for _, retention := range tests { + b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) { + file := NewMockFile() + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + Format: FormatJSON, + Retention: retention, + }) + for i := range 10000 { + MockTimeNow(testTime.AddDate(0, 0, -10000+i+1)) + logger.Log(req, resp) + if i%10 == 0 { + _, _ = file.Write([]byte("invalid time\n")) + } + } + logger.Flush() + content := file.Content() + b.ResetTimer() + for b.Loop() { + b.StopTimer() + file = NewMockFile() + _, _ = file.Write(content) + b.StartTimer() + _, _ = logger.Rotate() + } + }) + } +} diff --git a/internal/net/gphttp/accesslog/stdout_logger.go b/internal/net/gphttp/accesslog/stdout_logger.go new file mode 100644 index 00000000..2e1f2456 --- /dev/null +++ b/internal/net/gphttp/accesslog/stdout_logger.go @@ -0,0 +1,18 @@ +package accesslog + +import ( + "io" + "os" +) + +type StdoutLogger struct { + io.Writer +} + +var stdoutIO = &StdoutLogger{os.Stdout} + +func (l *StdoutLogger) Lock() {} +func (l *StdoutLogger) Unlock() {} +func (l *StdoutLogger) Name() string { + return "stdout" +} diff --git a/internal/net/gphttp/accesslog/time_now.go b/internal/net/gphttp/accesslog/time_now.go new file mode 100644 index 00000000..554c4c8f --- /dev/null +++ b/internal/net/gphttp/accesslog/time_now.go @@ -0,0 +1,48 @@ +package accesslog + +import ( + "time" + + "github.com/yusing/go-proxy/internal/task" + "go.uber.org/atomic" +) + +var ( + TimeNow = DefaultTimeNow + shouldCallTimeNow atomic.Bool + timeNowTicker = time.NewTicker(shouldCallTimeNowInterval) + lastTimeNow = time.Now() +) + +const shouldCallTimeNowInterval = 100 * time.Millisecond + +func MockTimeNow(t time.Time) { + TimeNow = func() time.Time { + return t + } +} + +// DefaultTimeNow is a time.Now wrapper that reduces the number of calls to time.Now +// by caching the result and only allow calling time.Now when the ticker fires. +// +// Returned value may have +-100ms error. +func DefaultTimeNow() time.Time { + if shouldCallTimeNow.Load() { + lastTimeNow = time.Now() + shouldCallTimeNow.Store(false) + } + return lastTimeNow +} + +func init() { + go func() { + for { + select { + case <-task.RootContext().Done(): + return + case <-timeNowTicker.C: + shouldCallTimeNow.Store(true) + } + } + }() +} diff --git a/internal/net/gphttp/accesslog/time_now_test.go b/internal/net/gphttp/accesslog/time_now_test.go new file mode 100644 index 00000000..ccc7eb8c --- /dev/null +++ b/internal/net/gphttp/accesslog/time_now_test.go @@ -0,0 +1,102 @@ +package accesslog + +import ( + "testing" + "time" +) + +func BenchmarkTimeNow(b *testing.B) { + b.Run("default", func(b *testing.B) { + for b.Loop() { + time.Now() + } + }) + + b.Run("reduced_call", func(b *testing.B) { + for b.Loop() { + DefaultTimeNow() + } + }) +} + +func TestDefaultTimeNow(t *testing.T) { + // Get initial time + t1 := DefaultTimeNow() + + // Second call should return the same time without calling time.Now + t2 := DefaultTimeNow() + + if !t1.Equal(t2) { + t.Errorf("Expected t1 == t2, got t1 = %v, t2 = %v", t1, t2) + } + + // Set shouldCallTimeNow to true + shouldCallTimeNow.Store(true) + + // This should update the lastTimeNow + t3 := DefaultTimeNow() + + // The time should have changed + if t2.Equal(t3) { + t.Errorf("Expected t2 != t3, got t2 = %v, t3 = %v", t2, t3) + } + + // Fourth call should return the same time as third call + t4 := DefaultTimeNow() + + if !t3.Equal(t4) { + t.Errorf("Expected t3 == t4, got t3 = %v, t4 = %v", t3, t4) + } +} + +func TestMockTimeNow(t *testing.T) { + // Save the original TimeNow function to restore later + originalTimeNow := TimeNow + defer func() { + TimeNow = originalTimeNow + }() + + // Create a fixed time + fixedTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC) + + // Mock the time + MockTimeNow(fixedTime) + + // TimeNow should return the fixed time + result := TimeNow() + + if !result.Equal(fixedTime) { + t.Errorf("Expected %v, got %v", fixedTime, result) + } +} + +func TestTimeNowTicker(t *testing.T) { + // This test verifies that the ticker properly updates shouldCallTimeNow + + // Reset the flag + shouldCallTimeNow.Store(false) + + // Wait for the ticker to tick (slightly more than the interval) + time.Sleep(shouldCallTimeNowInterval + 10*time.Millisecond) + + // The ticker should have set shouldCallTimeNow to true + if !shouldCallTimeNow.Load() { + t.Error("Expected shouldCallTimeNow to be true after ticker interval") + } + + // Call DefaultTimeNow which should reset the flag + DefaultTimeNow() + + // Check that the flag is reset + if shouldCallTimeNow.Load() { + t.Error("Expected shouldCallTimeNow to be false after calling DefaultTimeNow") + } +} + +/* +BenchmarkTimeNow +BenchmarkTimeNow/default +BenchmarkTimeNow/default-20 48158628 24.86 ns/op 0 B/op 0 allocs/op +BenchmarkTimeNow/reduced_call +BenchmarkTimeNow/reduced_call-20 1000000000 1.000 ns/op 0 B/op 0 allocs/op +*/ diff --git a/internal/net/gphttp/accesslog/units.go b/internal/net/gphttp/accesslog/units.go new file mode 100644 index 00000000..809f7ad3 --- /dev/null +++ b/internal/net/gphttp/accesslog/units.go @@ -0,0 +1,11 @@ +package accesslog + +const ( + kilobyte = 1024 + megabyte = 1024 * kilobyte + gigabyte = 1024 * megabyte + + kilobits = 1000 + megabits = 1000 * kilobits + gigabits = 1000 * megabits +) diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index e2cf0644..f7adde1a 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -84,7 +84,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error { if s.UseAccessLog() { var err error - s.accessLogger, err = accesslog.NewFileAccessLogger(s.task, s.AccessLog) + s.accessLogger, err = accesslog.NewAccessLogger(s.task, s.AccessLog) if err != nil { s.task.Finish(err) return gperr.Wrap(err) diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 69257190..0456848f 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -111,7 +111,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { if r.UseAccessLog() { var err error - r.rp.AccessLogger, err = accesslog.NewFileAccessLogger(r.task, r.AccessLog) + r.rp.AccessLogger, err = accesslog.NewAccessLogger(r.task, r.AccessLog) if err != nil { r.task.Finish(err) return gperr.Wrap(err) diff --git a/internal/utils/synk/pool.go b/internal/utils/synk/pool.go new file mode 100644 index 00000000..2f1164be --- /dev/null +++ b/internal/utils/synk/pool.go @@ -0,0 +1,42 @@ +package synk + +import "sync" + +type ( + // Pool is a wrapper of sync.Pool that limits the size of the object. + Pool[T any] struct { + pool sync.Pool + maxSize int + } + BytesPool = Pool[byte] +) + +const ( + DefaultInitBytes = 1024 + DefaultMaxBytes = 1024 * 1024 +) + +func NewPool[T any](initSize int, maxSize int) *Pool[T] { + return &Pool[T]{ + pool: sync.Pool{ + New: func() any { + return make([]T, 0, initSize) + }, + }, + maxSize: maxSize, + } +} + +func NewBytesPool(initSize int, maxSize int) *BytesPool { + return NewPool[byte](initSize, maxSize) +} + +func (p *Pool[T]) Get() []T { + return p.pool.Get().([]T) +} + +func (p *Pool[T]) Put(b []T) { + if cap(b) <= p.maxSize { + p.pool.Put(b[:0]) + } +}