fix(rules): simplify and correct tests

This commit is contained in:
yusing
2026-01-10 18:40:06 +08:00
committed by github-actions[bot]
parent f554f574e9
commit cb1b830b7e
4 changed files with 80 additions and 149 deletions

View File

@@ -5,8 +5,6 @@ import (
"maps" "maps"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@@ -44,18 +42,14 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
"Content-Type": []string{"application/json"}, "Content-Type": []string{"application/json"},
}) })
// Create a temporary file for logging logFile := TestRandomFileName()
tempFile, err := os.CreateTemp("", "test-log-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-request-response - name: log-request-response
do: | do: |
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)' log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
`, tempFile.Name()), &rules) `, logFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -70,8 +64,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
assert.Equal(t, "success response", w.Body.String()) assert.Equal(t, "success response", w.Body.String())
// Read and verify log content // Read and verify log content
content, err := os.ReadFile(tempFile.Name()) content := TestFileContent(logFile)
require.NoError(t, err)
logContent := string(content) logContent := string(content)
assert.Equal(t, "POST /api/users 200 application/json\n", logContent) assert.Equal(t, "POST /api/users 200 application/json\n", logContent)
@@ -106,24 +99,12 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
func TestLogCommand_DifferentLogLevels(t *testing.T) { func TestLogCommand_DifferentLogLevels(t *testing.T) {
upstream := mockUpstream(404, "not found") upstream := mockUpstream(404, "not found")
// Create temporary files for different log levels infoFile := TestRandomFileName()
infoFile, err := os.CreateTemp("", "test-info-*.log") warnFile := TestRandomFileName()
require.NoError(t, err) errorFile := TestRandomFileName()
infoFile.Close()
defer os.Remove(infoFile.Name())
warnFile, err := os.CreateTemp("", "test-warn-*.log")
require.NoError(t, err)
warnFile.Close()
defer os.Remove(warnFile.Name())
errorFile, err := os.CreateTemp("", "test-error-*.log")
require.NoError(t, err)
errorFile.Close()
defer os.Remove(errorFile.Name())
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-info - name: log-info
do: | do: |
log info %s "INFO: $req_method $status_code" log info %s "INFO: $req_method $status_code"
@@ -133,7 +114,7 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
- name: log-error - name: log-error
do: | do: |
log error %s "ERROR: $req_method $req_path $status_code" log error %s "ERROR: $req_method $req_path $status_code"
`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules) `, infoFile, warnFile, errorFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -146,16 +127,13 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
// Verify each log file // Verify each log file
infoContent, err := os.ReadFile(infoFile.Name()) infoContent := TestFileContent(infoFile)
require.NoError(t, err)
assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent))) assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent)))
warnContent, err := os.ReadFile(warnFile.Name()) warnContent := TestFileContent(warnFile)
require.NoError(t, err)
assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent))) assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent)))
errorContent, err := os.ReadFile(errorFile.Name()) errorContent := TestFileContent(errorFile)
require.NoError(t, err)
assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent))) assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent)))
} }
@@ -167,18 +145,14 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
w.Write([]byte("created")) w.Write([]byte("created"))
}) })
// Create temporary file tempFile := TestRandomFileName()
tempFile, err := os.CreateTemp("", "test-template-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-with-templates - name: log-with-templates
do: | do: |
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)' log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
`, tempFile.Name()), &rules) `, tempFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -193,8 +167,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
// Verify log content // Verify log content
content, err := os.ReadFile(tempFile.Name()) content := TestFileContent(tempFile)
require.NoError(t, err)
logContent := strings.TrimSpace(string(content)) logContent := strings.TrimSpace(string(content))
assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent) assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent)
@@ -215,19 +188,11 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
} }
}) })
// Create temporary files successFile := TestRandomFileName()
successFile, err := os.CreateTemp("", "test-success-*.log") errorFile := TestRandomFileName()
require.NoError(t, err)
successFile.Close()
defer os.Remove(successFile.Name())
errorFile, err := os.CreateTemp("", "test-error-*.log")
require.NoError(t, err)
errorFile.Close()
defer os.Remove(errorFile.Name())
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-success - name: log-success
on: status 2xx on: status 2xx
do: | do: |
@@ -236,7 +201,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
on: status 4xx | status 5xx on: status 4xx | status 5xx
do: | do: |
log error %q "ERROR: $req_method $req_path $status_code" log error %q "ERROR: $req_method $req_path $status_code"
`, successFile.Name(), errorFile.Name()), &rules) `, successFile, errorFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -260,15 +225,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
assert.Equal(t, 500, w3.Code) assert.Equal(t, 500, w3.Code)
// Verify success log // Verify success log
successContent, err := os.ReadFile(successFile.Name()) successContent := TestFileContent(successFile)
require.NoError(t, err)
successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n") successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n")
assert.Len(t, successLines, 1) assert.Len(t, successLines, 1)
assert.Equal(t, "SUCCESS: GET /success 200", successLines[0]) assert.Equal(t, "SUCCESS: GET /success 200", successLines[0])
// Verify error log // Verify error log
errorContent, err := os.ReadFile(errorFile.Name()) errorContent := TestFileContent(errorFile)
require.NoError(t, err)
errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n") errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n")
require.Len(t, errorLines, 2) require.Len(t, errorLines, 2)
assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0]) assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0])
@@ -278,17 +241,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
func TestLogCommand_MultipleLogEntries(t *testing.T) { func TestLogCommand_MultipleLogEntries(t *testing.T) {
upstream := mockUpstream(200, "response") upstream := mockUpstream(200, "response")
// Create temporary file tempFile := TestRandomFileName()
tempFile, err := os.CreateTemp("", "test-multiple-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-multiple - name: log-multiple
do: | do: |
log info %q "$req_method $req_path $status_code"`, tempFile.Name()), &rules) log info %q "$req_method $req_path $status_code"`, tempFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -312,8 +271,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
} }
// Verify all requests were logged // Verify all requests were logged
content, err := os.ReadFile(tempFile.Name()) content := TestFileContent(tempFile)
require.NoError(t, err)
logContent := strings.TrimSpace(string(content)) logContent := strings.TrimSpace(string(content))
lines := strings.Split(logContent, "\n") lines := strings.Split(logContent, "\n")
@@ -325,54 +283,6 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
} }
} }
func TestLogCommand_FilePermissions(t *testing.T) {
upstream := mockUpstream(200, "success")
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "test-log-dir")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a log file path within the temp directory
logFilePath := filepath.Join(tempDir, "test.log")
var rules Rules
err = parseRules(fmt.Sprintf(`
- on: status 2xx
do: log info %q "$req_method $status_code"`, logFilePath), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
// Verify file was created and is writable
_, err = os.Stat(logFilePath)
require.NoError(t, err)
// Test writing to the file again to ensure it's not closed
req2 := httptest.NewRequest("POST", "/test2", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
// Verify both entries are in the file
content, err := os.ReadFile(logFilePath)
require.NoError(t, err)
logContent := strings.TrimSpace(string(content))
lines := strings.Split(logContent, "\n")
require.Len(t, lines, 2)
assert.Equal(t, "GET 200", lines[0])
assert.Equal(t, "POST 200", lines[1])
}
func TestLogCommand_InvalidTemplate(t *testing.T) { func TestLogCommand_InvalidTemplate(t *testing.T) {
var rules Rules var rules Rules

View File

@@ -208,18 +208,14 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
"X-Upstream": []string{"upstream-value"}, "X-Upstream": []string{"upstream-value"},
}) })
tempFile, err := os.CreateTemp("", "test-log-*.txt") tempFile := TestRandomFileName()
// Create a temporary file for logging
require.NoError(t, err)
defer os.Remove(tempFile.Name())
tempFile.Close()
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-response - name: log-response
on: path /test on: path /test
do: log info %s "$req_method $status_code" do: log info %s "$req_method $status_code"
`, tempFile.Name()), &rules) `, tempFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -234,7 +230,7 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
// Check log file // Check log file
content, err := os.ReadFile(tempFile.Name()) content := TestFileContent(tempFile)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "GET 200\n", string(content)) assert.Equal(t, "GET 200\n", string(content))
} }
@@ -253,16 +249,13 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
var rules Rules var rules Rules
// Create a temporary file for logging // Create a temporary file for logging
tempFile, err := os.CreateTemp("", "test-error-log-*.txt") tempFile := TestRandomFileName()
require.NoError(t, err)
defer os.Remove(tempFile.Name())
tempFile.Close()
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: log-errors - name: log-errors
on: status 4xx on: status 4xx
do: log error %s "$req_url returned $status_code" do: log error %s "$req_url returned $status_code"
`, tempFile.Name()), &rules) `, tempFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -282,7 +275,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
assert.Equal(t, 404, w2.Code) assert.Equal(t, 404, w2.Code)
// Check log file // Check log file
content, err := os.ReadFile(tempFile.Name()) content := TestFileContent(tempFile)
require.NoError(t, err) require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(string(content)), "\n") lines := strings.Split(strings.TrimSpace(string(content)), "\n")
require.Len(t, lines, 1, "only 4xx requests should be logged") require.Len(t, lines, 1, "only 4xx requests should be logged")
@@ -345,18 +338,11 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
}) })
// Create temporary files for logging // Create temporary files for logging
logFile, err := os.CreateTemp("", "test-access-log-*.txt") logFile := TestRandomFileName()
require.NoError(t, err) errorLogFile := TestRandomFileName()
defer os.Remove(logFile.Name())
logFile.Close()
errorLogFile, err := os.CreateTemp("", "test-error-log-*.txt")
require.NoError(t, err)
defer os.Remove(errorLogFile.Name())
errorLogFile.Close()
var rules Rules var rules Rules
err = parseRules(fmt.Sprintf(` err := parseRules(fmt.Sprintf(`
- name: add-correlation-id - name: add-correlation-id
do: set resp_header X-Correlation-Id random_uuid do: set resp_header X-Correlation-Id random_uuid
- name: validate-auth - name: validate-auth
@@ -369,7 +355,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
on: status 4xx on: status 4xx
do: | do: |
log error %q "ERROR: $req_method $req_url $status_code" log error %q "ERROR: $req_method $req_url $status_code"
`, logFile.Name(), errorLogFile.Name()), &rules) `, logFile, errorLogFile), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
@@ -403,16 +389,14 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
assert.Equal(t, 401, w3.Code) assert.Equal(t, 401, w3.Code)
// Check log files // Check log files
logContent, err := os.ReadFile(logFile.Name()) logContent := TestFileContent(logFile)
require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(string(logContent)), "\n") lines := strings.Split(strings.TrimSpace(string(logContent)), "\n")
require.Len(t, lines, 3, "all requests should be logged") require.Len(t, lines, 3, "all requests should be logged")
assert.Equal(t, "GET /public -> 200", lines[0]) assert.Equal(t, "GET /public -> 200", lines[0])
assert.Equal(t, "GET /protected -> 401", lines[1]) assert.Equal(t, "GET /protected -> 401", lines[1])
assert.Equal(t, "GET /protected -> 401", lines[2]) assert.Equal(t, "GET /protected -> 401", lines[2])
errorLogContent, err := os.ReadFile(errorLogFile.Name()) errorLogContent := TestFileContent(errorLogFile)
require.NoError(t, err)
// Should have at least one 401 error logged // Should have at least one 401 error logged
lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n") lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n")
require.Len(t, lines, 2, "all errors should be logged") require.Len(t, lines, 2, "all errors should be logged")

View File

@@ -1,9 +1,14 @@
package rules package rules
import ( import (
"bytes"
"fmt"
"io" "io"
"math/rand"
"os" "os"
"sync"
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/logging/accesslog" "github.com/yusing/godoxy/internal/logging/accesslog"
gperr "github.com/yusing/goutils/errs" gperr "github.com/yusing/goutils/errs"
) )
@@ -21,6 +26,11 @@ var (
stderr io.WriteCloser = noopWriteCloser{os.Stderr} stderr io.WriteCloser = noopWriteCloser{os.Stderr}
) )
var (
testFiles = make(map[string]*bytes.Buffer)
testFilesLock sync.Mutex
)
func openFile(path string) (io.WriteCloser, gperr.Error) { func openFile(path string) (io.WriteCloser, gperr.Error) {
switch path { switch path {
case "/dev/stdout": case "/dev/stdout":
@@ -28,9 +38,36 @@ func openFile(path string) (io.WriteCloser, gperr.Error) {
case "/dev/stderr": case "/dev/stderr":
return stderr, nil return stderr, nil
} }
if common.IsTest {
testFilesLock.Lock()
defer testFilesLock.Unlock()
if buf, ok := testFiles[path]; ok {
return noopWriteCloser{buf}, nil
}
buf := bytes.NewBuffer(nil)
testFiles[path] = buf
return noopWriteCloser{buf}, nil
}
f, err := accesslog.NewFileIO(path) f, err := accesslog.NewFileIO(path)
if err != nil { if err != nil {
return nil, ErrInvalidArguments.With(err) return nil, ErrInvalidArguments.With(err)
} }
return f, nil return f, nil
} }
func TestRandomFileName() string {
return fmt.Sprintf("test-file-%d.txt", rand.Intn(1000000))
}
func TestFileContent(path string) []byte {
testFilesLock.Lock()
defer testFilesLock.Unlock()
buf, ok := testFiles[path]
if !ok {
return nil
}
return buf.Bytes()
}

View File

@@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
var out strings.Builder var out strings.Builder
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, tt.input, &out) err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
if tt.wantErr { if tt.wantErr {
require.Error(t, err) require.Error(t, err)
@@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) {
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(200)
var out strings.Builder var out strings.Builder
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, err := ExpandVars(testResponseModifier, testRequest,
"$req_method $req_url $status_code User-Agent=$header(User-Agent)", "$req_method $req_url $status_code User-Agent=$header(User-Agent)",
&out) &out)
@@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) {
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(200)
var out strings.Builder var out strings.Builder
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, err := ExpandVars(testResponseModifier, testRequest,
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)", "Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
&out) &out)