diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 4dbf7d70..5e16e971 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -470,17 +470,29 @@ var commands = map[string]struct { build: func(args any) HandlerFunc { level, f, tmpl := args.(*onLogArgs).Unpack() var logger io.Writer - if f == stdout || f == stderr { + isStdLogger := f == stdout || f == stderr + if isStdLogger { logger = logging.NewLoggerWithFixedLevel(level, f) } else { logger = f } return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { - _, err := tmpl.ExpandVars(w, r, logger) - if err != nil { + if isStdLogger { + bufPool := w.BufPool() + buf := bufPool.GetBuffer() + defer bufPool.PutBuffer(buf) + + if _, err := tmpl.ExpandVars(w, r, buf); err != nil { + return err + } + if buf.Len() == 0 { + return nil + } + _, err := logger.Write(buf.Bytes()) return err } - return nil + _, err := tmpl.ExpandVars(w, r, logger) + return err } }, }, diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index 191a0732..dcd958be 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -1,6 +1,7 @@ package rules import ( + "bytes" "fmt" "maps" "net/http" @@ -8,6 +9,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -69,6 +71,17 @@ default { } func TestLogCommand_StdoutAndStderr(t *testing.T) { + originalStdout := stdout + originalStderr := stderr + var stdoutBuf bytes.Buffer + var stderrBuf bytes.Buffer + stdout = noopWriteCloser{&stdoutBuf} + stderr = noopWriteCloser{&stderrBuf} + defer func() { + stdout = originalStdout + stderr = originalStderr + }() + upstream := mockUpstream(http.StatusOK, "success") var rules Rules @@ -88,8 +101,12 @@ default { handler.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - // Note: We can't easily capture stdout/stderr in unit tests, - // but we can verify no errors occurred and the handler completed + require.Eventually(t, func() bool { + return strings.Contains(stdoutBuf.String(), "stdout: GET 200") && + strings.Contains(stderrBuf.String(), "stderr: /test 200") + }, time.Second, 10*time.Millisecond) + assert.Equal(t, 1, strings.Count(stdoutBuf.String(), "stdout: GET 200")) + assert.Equal(t, 1, strings.Count(stderrBuf.String(), "stderr: /test 200")) } func TestLogCommand_DifferentLogLevels(t *testing.T) {