From cb1b830b7ead016c9486ba813f795d61fde9b8ae Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 10 Jan 2026 18:40:06 +0800 Subject: [PATCH] fix(rules): simplify and correct tests --- internal/route/rules/do_log_test.go | 142 +++++-------------------- internal/route/rules/http_flow_test.go | 44 +++----- internal/route/rules/io.go | 37 +++++++ internal/route/rules/vars_test.go | 6 +- 4 files changed, 80 insertions(+), 149 deletions(-) diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index b21b1956..e7741ab1 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -5,8 +5,6 @@ import ( "maps" "net/http" "net/http/httptest" - "os" - "path/filepath" "reflect" "strings" "testing" @@ -44,18 +42,14 @@ func TestLogCommand_TemporaryFile(t *testing.T) { "Content-Type": []string{"application/json"}, }) - // Create a temporary file for logging - tempFile, err := os.CreateTemp("", "test-log-*.log") - require.NoError(t, err) - tempFile.Close() - defer os.Remove(tempFile.Name()) + logFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-request-response do: | log info %q '$req_method $req_url $status_code $resp_header(Content-Type)' -`, tempFile.Name()), &rules) +`, logFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -70,8 +64,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) { assert.Equal(t, "success response", w.Body.String()) // Read and verify log content - content, err := os.ReadFile(tempFile.Name()) - require.NoError(t, err) + content := TestFileContent(logFile) logContent := string(content) assert.Equal(t, "POST /api/users 200 application/json\n", logContent) @@ -106,24 +99,12 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { func TestLogCommand_DifferentLogLevels(t *testing.T) { upstream := mockUpstream(404, "not found") - // Create temporary files for different log levels - infoFile, err := os.CreateTemp("", "test-info-*.log") - require.NoError(t, err) - infoFile.Close() - defer os.Remove(infoFile.Name()) - - warnFile, err := os.CreateTemp("", "test-warn-*.log") - require.NoError(t, err) - warnFile.Close() - defer os.Remove(warnFile.Name()) - - errorFile, err := os.CreateTemp("", "test-error-*.log") - require.NoError(t, err) - errorFile.Close() - defer os.Remove(errorFile.Name()) + infoFile := TestRandomFileName() + warnFile := TestRandomFileName() + errorFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-info do: | log info %s "INFO: $req_method $status_code" @@ -133,7 +114,7 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) { - name: log-error do: | log error %s "ERROR: $req_method $req_path $status_code" -`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules) +`, infoFile, warnFile, errorFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -146,16 +127,13 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) { assert.Equal(t, 404, w.Code) // Verify each log file - infoContent, err := os.ReadFile(infoFile.Name()) - require.NoError(t, err) + infoContent := TestFileContent(infoFile) assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent))) - warnContent, err := os.ReadFile(warnFile.Name()) - require.NoError(t, err) + warnContent := TestFileContent(warnFile) assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent))) - errorContent, err := os.ReadFile(errorFile.Name()) - require.NoError(t, err) + errorContent := TestFileContent(errorFile) assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent))) } @@ -167,18 +145,14 @@ func TestLogCommand_TemplateVariables(t *testing.T) { w.Write([]byte("created")) }) - // Create temporary file - tempFile, err := os.CreateTemp("", "test-template-*.log") - require.NoError(t, err) - tempFile.Close() - defer os.Remove(tempFile.Name()) + tempFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-with-templates do: | log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)' -`, tempFile.Name()), &rules) +`, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -193,8 +167,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) { assert.Equal(t, 201, w.Code) // Verify log content - content, err := os.ReadFile(tempFile.Name()) - require.NoError(t, err) + content := TestFileContent(tempFile) logContent := strings.TrimSpace(string(content)) assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent) @@ -215,19 +188,11 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { } }) - // Create temporary files - successFile, err := os.CreateTemp("", "test-success-*.log") - require.NoError(t, err) - successFile.Close() - defer os.Remove(successFile.Name()) - - errorFile, err := os.CreateTemp("", "test-error-*.log") - require.NoError(t, err) - errorFile.Close() - defer os.Remove(errorFile.Name()) + successFile := TestRandomFileName() + errorFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-success on: status 2xx do: | @@ -236,7 +201,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { on: status 4xx | status 5xx do: | log error %q "ERROR: $req_method $req_path $status_code" -`, successFile.Name(), errorFile.Name()), &rules) +`, successFile, errorFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -260,15 +225,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { assert.Equal(t, 500, w3.Code) // Verify success log - successContent, err := os.ReadFile(successFile.Name()) - require.NoError(t, err) + successContent := TestFileContent(successFile) successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n") assert.Len(t, successLines, 1) assert.Equal(t, "SUCCESS: GET /success 200", successLines[0]) // Verify error log - errorContent, err := os.ReadFile(errorFile.Name()) - require.NoError(t, err) + errorContent := TestFileContent(errorFile) errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n") require.Len(t, errorLines, 2) assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0]) @@ -278,17 +241,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { func TestLogCommand_MultipleLogEntries(t *testing.T) { upstream := mockUpstream(200, "response") - // Create temporary file - tempFile, err := os.CreateTemp("", "test-multiple-*.log") - require.NoError(t, err) - tempFile.Close() - defer os.Remove(tempFile.Name()) + tempFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-multiple do: | - log info %q "$req_method $req_path $status_code"`, tempFile.Name()), &rules) + log info %q "$req_method $req_path $status_code"`, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -312,8 +271,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { } // Verify all requests were logged - content, err := os.ReadFile(tempFile.Name()) - require.NoError(t, err) + content := TestFileContent(tempFile) logContent := strings.TrimSpace(string(content)) lines := strings.Split(logContent, "\n") @@ -325,54 +283,6 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { } } -func TestLogCommand_FilePermissions(t *testing.T) { - upstream := mockUpstream(200, "success") - - // Create a temporary directory - tempDir, err := os.MkdirTemp("", "test-log-dir") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - // Create a log file path within the temp directory - logFilePath := filepath.Join(tempDir, "test.log") - - var rules Rules - err = parseRules(fmt.Sprintf(` -- on: status 2xx - do: log info %q "$req_method $status_code"`, logFilePath), &rules) - require.NoError(t, err) - - handler := rules.BuildHandler(upstream) - - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - handler.ServeHTTP(w, req) - - assert.Equal(t, 200, w.Code) - - // Verify file was created and is writable - _, err = os.Stat(logFilePath) - require.NoError(t, err) - - // Test writing to the file again to ensure it's not closed - req2 := httptest.NewRequest("POST", "/test2", nil) - w2 := httptest.NewRecorder() - handler.ServeHTTP(w2, req2) - - assert.Equal(t, 200, w2.Code) - - // Verify both entries are in the file - content, err := os.ReadFile(logFilePath) - require.NoError(t, err) - logContent := strings.TrimSpace(string(content)) - lines := strings.Split(logContent, "\n") - - require.Len(t, lines, 2) - assert.Equal(t, "GET 200", lines[0]) - assert.Equal(t, "POST 200", lines[1]) -} - func TestLogCommand_InvalidTemplate(t *testing.T) { var rules Rules diff --git a/internal/route/rules/http_flow_test.go b/internal/route/rules/http_flow_test.go index f0bf741a..b8321e85 100644 --- a/internal/route/rules/http_flow_test.go +++ b/internal/route/rules/http_flow_test.go @@ -208,18 +208,14 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { "X-Upstream": []string{"upstream-value"}, }) - tempFile, err := os.CreateTemp("", "test-log-*.txt") - // Create a temporary file for logging - require.NoError(t, err) - defer os.Remove(tempFile.Name()) - tempFile.Close() + tempFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-response on: path /test do: log info %s "$req_method $status_code" -`, tempFile.Name()), &rules) +`, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -234,7 +230,7 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) // Check log file - content, err := os.ReadFile(tempFile.Name()) + content := TestFileContent(tempFile) require.NoError(t, err) assert.Equal(t, "GET 200\n", string(content)) } @@ -253,16 +249,13 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { var rules Rules // Create a temporary file for logging - tempFile, err := os.CreateTemp("", "test-error-log-*.txt") - require.NoError(t, err) - defer os.Remove(tempFile.Name()) - tempFile.Close() + tempFile := TestRandomFileName() - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: log-errors on: status 4xx do: log error %s "$req_url returned $status_code" -`, tempFile.Name()), &rules) +`, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -282,7 +275,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { assert.Equal(t, 404, w2.Code) // Check log file - content, err := os.ReadFile(tempFile.Name()) + content := TestFileContent(tempFile) require.NoError(t, err) lines := strings.Split(strings.TrimSpace(string(content)), "\n") require.Len(t, lines, 1, "only 4xx requests should be logged") @@ -345,18 +338,11 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { }) // Create temporary files for logging - logFile, err := os.CreateTemp("", "test-access-log-*.txt") - require.NoError(t, err) - defer os.Remove(logFile.Name()) - logFile.Close() - - errorLogFile, err := os.CreateTemp("", "test-error-log-*.txt") - require.NoError(t, err) - defer os.Remove(errorLogFile.Name()) - errorLogFile.Close() + logFile := TestRandomFileName() + errorLogFile := TestRandomFileName() var rules Rules - err = parseRules(fmt.Sprintf(` + err := parseRules(fmt.Sprintf(` - name: add-correlation-id do: set resp_header X-Correlation-Id random_uuid - name: validate-auth @@ -369,7 +355,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { on: status 4xx do: | log error %q "ERROR: $req_method $req_url $status_code" -`, logFile.Name(), errorLogFile.Name()), &rules) +`, logFile, errorLogFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -403,16 +389,14 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { assert.Equal(t, 401, w3.Code) // Check log files - logContent, err := os.ReadFile(logFile.Name()) - require.NoError(t, err) + logContent := TestFileContent(logFile) lines := strings.Split(strings.TrimSpace(string(logContent)), "\n") require.Len(t, lines, 3, "all requests should be logged") assert.Equal(t, "GET /public -> 200", lines[0]) assert.Equal(t, "GET /protected -> 401", lines[1]) assert.Equal(t, "GET /protected -> 401", lines[2]) - errorLogContent, err := os.ReadFile(errorLogFile.Name()) - require.NoError(t, err) + errorLogContent := TestFileContent(errorLogFile) // Should have at least one 401 error logged lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n") require.Len(t, lines, 2, "all errors should be logged") diff --git a/internal/route/rules/io.go b/internal/route/rules/io.go index 8cfab7b2..de2e7088 100644 --- a/internal/route/rules/io.go +++ b/internal/route/rules/io.go @@ -1,9 +1,14 @@ package rules import ( + "bytes" + "fmt" "io" + "math/rand" "os" + "sync" + "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/logging/accesslog" gperr "github.com/yusing/goutils/errs" ) @@ -21,6 +26,11 @@ var ( stderr io.WriteCloser = noopWriteCloser{os.Stderr} ) +var ( + testFiles = make(map[string]*bytes.Buffer) + testFilesLock sync.Mutex +) + func openFile(path string) (io.WriteCloser, gperr.Error) { switch path { case "/dev/stdout": @@ -28,9 +38,36 @@ func openFile(path string) (io.WriteCloser, gperr.Error) { case "/dev/stderr": return stderr, nil } + + if common.IsTest { + testFilesLock.Lock() + defer testFilesLock.Unlock() + if buf, ok := testFiles[path]; ok { + return noopWriteCloser{buf}, nil + } + buf := bytes.NewBuffer(nil) + testFiles[path] = buf + return noopWriteCloser{buf}, nil + } + f, err := accesslog.NewFileIO(path) if err != nil { return nil, ErrInvalidArguments.With(err) } return f, nil } + +func TestRandomFileName() string { + return fmt.Sprintf("test-file-%d.txt", rand.Intn(1000000)) +} + +func TestFileContent(path string) []byte { + testFilesLock.Lock() + defer testFilesLock.Unlock() + + buf, ok := testFiles[path] + if !ok { + return nil + } + return buf.Bytes() +} diff --git a/internal/route/rules/vars_test.go b/internal/route/rules/vars_test.go index 3a090751..c719a84a 100644 --- a/internal/route/rules/vars_test.go +++ b/internal/route/rules/vars_test.go @@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var out strings.Builder - err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, tt.input, &out) + err := ExpandVars(testResponseModifier, testRequest, tt.input, &out) if tt.wantErr { require.Error(t, err) @@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) { testResponseModifier.WriteHeader(200) var out strings.Builder - err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, + err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_url $status_code User-Agent=$header(User-Agent)", &out) @@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) { testResponseModifier.WriteHeader(200) var out strings.Builder - err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, + err := ExpandVars(testResponseModifier, testRequest, "Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)", &out)