mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-18 17:37:04 +01:00
645 lines
16 KiB
Go
645 lines
16 KiB
Go
package rules
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
httputils "github.com/yusing/goutils/http"
|
|
)
|
|
|
|
func TestFieldHandler_Header(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
value string
|
|
modifier FieldModifier
|
|
setup func(*http.Request)
|
|
verify func(*http.Request, *httptest.ResponseRecorder)
|
|
}{
|
|
{
|
|
name: "set header",
|
|
key: "X-Test",
|
|
value: "test-value",
|
|
modifier: ModFieldSet,
|
|
setup: func(r *http.Request) {
|
|
r.Header.Set("X-Test", "old-value")
|
|
},
|
|
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
|
got := r.Header.Get("X-Test")
|
|
assert.Equal(t, "test-value", got, "Expected header X-Test to be 'test-value'")
|
|
},
|
|
},
|
|
{
|
|
name: "add header",
|
|
key: "X-Test",
|
|
value: "new-value",
|
|
modifier: ModFieldAdd,
|
|
setup: func(r *http.Request) {
|
|
r.Header.Set("X-Test", "existing-value")
|
|
},
|
|
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
|
values := r.Header["X-Test"]
|
|
require.Len(t, values, 2, "Expected 2 header values")
|
|
assert.Equal(t, "existing-value", values[0], "Expected first value of X-Test header to be 'existing-value'")
|
|
assert.Equal(t, "new-value", values[1], "Expected second value of X-Test header to be 'new-value'")
|
|
},
|
|
},
|
|
{
|
|
name: "remove header",
|
|
key: "X-Test",
|
|
value: "",
|
|
modifier: ModFieldRemove,
|
|
setup: func(r *http.Request) {
|
|
r.Header.Set("X-Test", "to-be-removed")
|
|
},
|
|
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
|
got := r.Header.Get("X-Test")
|
|
assert.Empty(t, got, "Expected header X-Test to be removed")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tt.setup(req)
|
|
w := httptest.NewRecorder()
|
|
|
|
tmpl, tErr := validateTemplate(tt.value, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to validate template: %v", tErr)
|
|
}
|
|
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
|
var cmd CommandHandler
|
|
switch tt.modifier {
|
|
case ModFieldSet:
|
|
cmd = handler.set
|
|
case ModFieldAdd:
|
|
cmd = handler.add
|
|
case ModFieldRemove:
|
|
cmd = handler.remove
|
|
}
|
|
|
|
err := cmd.Handle(w, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(req, w)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_ResponseHeader(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
value string
|
|
modifier FieldModifier
|
|
setup func(*httptest.ResponseRecorder)
|
|
verify func(*httptest.ResponseRecorder)
|
|
}{
|
|
{
|
|
name: "set response header",
|
|
key: "X-Response-Test",
|
|
value: "response-value",
|
|
modifier: ModFieldSet,
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
got := w.Header().Get("X-Response-Test")
|
|
assert.Equal(t, "response-value", got, "Expected response header X-Response-Test to be 'response-value'")
|
|
},
|
|
},
|
|
{
|
|
name: "add response header",
|
|
key: "X-Response-Test",
|
|
value: "additional-value",
|
|
modifier: ModFieldAdd,
|
|
setup: func(w *httptest.ResponseRecorder) {
|
|
w.Header().Set("X-Response-Test", "existing-value")
|
|
},
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
values := w.Header()["X-Response-Test"]
|
|
require.Len(t, values, 2)
|
|
assert.Equal(t, values[0], "existing-value")
|
|
assert.Equal(t, values[1], "additional-value")
|
|
},
|
|
},
|
|
{
|
|
name: "remove response header",
|
|
key: "X-Response-Test",
|
|
value: "",
|
|
modifier: ModFieldRemove,
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
assert.Empty(t, w.Header().Get("X-Response-Test"))
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
w := httptest.NewRecorder()
|
|
if tt.setup != nil {
|
|
tt.setup(w)
|
|
}
|
|
|
|
tmpl, tErr := validateTemplate(tt.value, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to validate template: %v", tErr)
|
|
}
|
|
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
|
var cmd CommandHandler
|
|
switch tt.modifier {
|
|
case ModFieldSet:
|
|
cmd = handler.set
|
|
case ModFieldAdd:
|
|
cmd = handler.add
|
|
case ModFieldRemove:
|
|
cmd = handler.remove
|
|
}
|
|
|
|
err := cmd.Handle(w, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(w)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_Query(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
value string
|
|
modifier FieldModifier
|
|
setup func(*http.Request)
|
|
verify func(*http.Request)
|
|
}{
|
|
{
|
|
name: "set query",
|
|
key: "test",
|
|
value: "new-value",
|
|
modifier: ModFieldSet,
|
|
setup: func(r *http.Request) {
|
|
r.URL.RawQuery = "test=old-value&other=keep"
|
|
},
|
|
verify: func(r *http.Request) {
|
|
got := r.URL.Query().Get("test")
|
|
assert.Equal(t, "new-value", got, "Expected query 'test' to be 'new-value'")
|
|
gotOther := r.URL.Query().Get("other")
|
|
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
|
|
},
|
|
},
|
|
{
|
|
name: "add query",
|
|
key: "test",
|
|
value: "additional-value",
|
|
modifier: ModFieldAdd,
|
|
setup: func(r *http.Request) {
|
|
r.URL.RawQuery = "test=existing-value"
|
|
},
|
|
verify: func(r *http.Request) {
|
|
values := r.URL.Query()["test"]
|
|
require.Len(t, values, 2, "Expected 2 query values")
|
|
assert.Equal(t, "existing-value", values[0], "Expected first value of test query param to be 'existing-value'")
|
|
assert.Equal(t, "additional-value", values[1], "Expected second value of test query param to be 'additional-value'")
|
|
},
|
|
},
|
|
{
|
|
name: "remove query",
|
|
key: "test",
|
|
value: "",
|
|
modifier: ModFieldRemove,
|
|
setup: func(r *http.Request) {
|
|
r.URL.RawQuery = "test=to-be-removed&other=keep"
|
|
},
|
|
verify: func(r *http.Request) {
|
|
got := r.URL.Query().Get("test")
|
|
assert.Empty(t, got, "Expected query 'test' to be removed")
|
|
gotOther := r.URL.Query().Get("other")
|
|
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tt.setup(req)
|
|
w := httptest.NewRecorder()
|
|
|
|
tmpl, tErr := validateTemplate(tt.value, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to validate template: %v", tErr)
|
|
}
|
|
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
|
var cmd CommandHandler
|
|
switch tt.modifier {
|
|
case ModFieldSet:
|
|
cmd = handler.set
|
|
case ModFieldAdd:
|
|
cmd = handler.add
|
|
case ModFieldRemove:
|
|
cmd = handler.remove
|
|
}
|
|
|
|
err := cmd.Handle(w, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(req)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_Cookie(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
value string
|
|
modifier FieldModifier
|
|
setup func(*http.Request)
|
|
verify func(*http.Request)
|
|
}{
|
|
{
|
|
name: "set cookie",
|
|
key: "test",
|
|
value: "new-value",
|
|
modifier: ModFieldSet,
|
|
setup: func(r *http.Request) {
|
|
r.AddCookie(&http.Cookie{Name: "test", Value: "old-value"})
|
|
},
|
|
verify: func(r *http.Request) {
|
|
cookie, err := r.Cookie("test")
|
|
assert.NoError(t, err, "Expected cookie 'test' to exist")
|
|
if err == nil {
|
|
assert.Equal(t, "new-value", cookie.Value, "Expected cookie 'test' to be 'new-value'")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "add cookie",
|
|
key: "test",
|
|
value: "additional-value",
|
|
modifier: ModFieldAdd,
|
|
setup: func(r *http.Request) {
|
|
r.AddCookie(&http.Cookie{Name: "test", Value: "existing-value"})
|
|
},
|
|
verify: func(r *http.Request) {
|
|
cookies := r.Cookies()
|
|
testCookies := make([]string, 0)
|
|
for _, c := range cookies {
|
|
if c.Name == "test" {
|
|
testCookies = append(testCookies, c.Value)
|
|
}
|
|
}
|
|
require.Len(t, testCookies, 2, "Expected 2 cookies with name 'test'")
|
|
assert.Equal(t, "existing-value", testCookies[0], "Expected first value of 'test' cookie to be 'existing-value'")
|
|
assert.Equal(t, "additional-value", testCookies[1], "Expected second value of 'test' cookie to be 'additional-value'")
|
|
},
|
|
},
|
|
{
|
|
name: "remove cookie",
|
|
key: "test",
|
|
value: "",
|
|
modifier: ModFieldRemove,
|
|
setup: func(r *http.Request) {
|
|
r.AddCookie(&http.Cookie{Name: "test", Value: "to-be-removed"})
|
|
r.AddCookie(&http.Cookie{Name: "other", Value: "keep"})
|
|
},
|
|
verify: func(r *http.Request) {
|
|
_, err := r.Cookie("test")
|
|
assert.Error(t, err, "Expected cookie 'test' to be removed")
|
|
cookie, err := r.Cookie("other")
|
|
assert.NoError(t, err, "Expected cookie 'other' to exist")
|
|
if err == nil {
|
|
assert.Equal(t, "keep", cookie.Value, "Expected cookie 'other' to be 'keep'")
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tt.setup(req)
|
|
w := httptest.NewRecorder()
|
|
|
|
tmpl, tErr := validateTemplate(tt.value, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to validate template: %v", tErr)
|
|
}
|
|
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
|
var cmd CommandHandler
|
|
switch tt.modifier {
|
|
case ModFieldSet:
|
|
cmd = handler.set
|
|
case ModFieldAdd:
|
|
cmd = handler.add
|
|
case ModFieldRemove:
|
|
cmd = handler.remove
|
|
}
|
|
|
|
err := cmd.Handle(w, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(req)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_Body(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
template string
|
|
setup func(*http.Request)
|
|
verify func(*http.Request)
|
|
}{
|
|
{
|
|
name: "set body with template",
|
|
template: "Hello $req_method $req_path",
|
|
setup: func(r *http.Request) {
|
|
r.Method = "POST"
|
|
r.URL.Path = "/test"
|
|
},
|
|
verify: func(r *http.Request) {
|
|
body, err := io.ReadAll(r.Body)
|
|
assert.NoError(t, err, "Failed to read body")
|
|
expected := "Hello POST /test"
|
|
assert.Equal(t, expected, string(body), "Expected body content")
|
|
},
|
|
},
|
|
{
|
|
name: "set body with existing body",
|
|
template: "Overridden",
|
|
setup: func(r *http.Request) {
|
|
r.Body = io.NopCloser(strings.NewReader("original body"))
|
|
},
|
|
verify: func(r *http.Request) {
|
|
body, err := io.ReadAll(r.Body)
|
|
assert.NoError(t, err, "Failed to read body")
|
|
assert.Equal(t, "Overridden", string(body), "Expected body to be 'Overridden'")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tt.setup(req)
|
|
w := httptest.NewRecorder()
|
|
|
|
tmpl, tErr := validateTemplate(tt.template, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to parse template: %v", tErr)
|
|
}
|
|
|
|
handler := modFields[FieldBody].builder(tmpl)
|
|
err := handler.set.Handle(w, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(req)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_ResponseBody(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
template string
|
|
setup func(*http.Request)
|
|
verify func(*httputils.ResponseModifier)
|
|
}{
|
|
{
|
|
name: "set response body with template",
|
|
template: "Response: $req_method $req_path",
|
|
setup: func(r *http.Request) {
|
|
r.Method = "GET"
|
|
r.URL.Path = "/api/test"
|
|
},
|
|
verify: func(rm *httputils.ResponseModifier) {
|
|
content := string(rm.Content())
|
|
expected := "Response: GET /api/test"
|
|
assert.Equal(t, expected, content, "Expected response body")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tt.setup(req)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Create ResponseModifier wrapper
|
|
rm := httputils.NewResponseModifier(w)
|
|
|
|
tmpl, tErr := validateTemplate(tt.template, false)
|
|
if tErr != nil {
|
|
t.Fatalf("Failed to parse template: %v", tErr)
|
|
}
|
|
|
|
handler := modFields[FieldResponseBody].builder(tmpl)
|
|
err := handler.set.Handle(rm, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
|
|
tt.verify(rm)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldHandler_StatusCode(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
status int
|
|
verify func(*httptest.ResponseRecorder)
|
|
}{
|
|
{
|
|
name: "set status code 200",
|
|
status: 200,
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
assert.Equal(t, 200, w.Code, "Expected status code 200")
|
|
},
|
|
},
|
|
{
|
|
name: "set status code 404",
|
|
status: 404,
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
assert.Equal(t, 404, w.Code, "Expected status code 404")
|
|
},
|
|
},
|
|
{
|
|
name: "set status code 500",
|
|
status: 500,
|
|
verify: func(w *httptest.ResponseRecorder) {
|
|
assert.Equal(t, 500, w.Code, "Expected status code 500")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
w := httptest.NewRecorder()
|
|
rm := httputils.NewResponseModifier(w)
|
|
var cmd Command
|
|
err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status))
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
err = cmd.ServeHTTP(rm, req)
|
|
if err != nil {
|
|
t.Fatalf("Handler returned error: %v", err)
|
|
}
|
|
rm.FlushRelease()
|
|
|
|
tt.verify(w)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFieldValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
field string
|
|
args []string
|
|
wantError bool
|
|
}{
|
|
{
|
|
name: "header valid",
|
|
field: FieldHeader,
|
|
args: []string{"key", "value"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "header invalid - missing value",
|
|
field: FieldHeader,
|
|
args: []string{"key"},
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "response header valid",
|
|
field: FieldResponseHeader,
|
|
args: []string{"key", "value"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "query valid",
|
|
field: FieldQuery,
|
|
args: []string{"key", "value"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "cookie valid",
|
|
field: FieldCookie,
|
|
args: []string{"key", "value"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "body valid template",
|
|
field: FieldBody,
|
|
args: []string{"Hello $req_method"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "body invalid template syntax",
|
|
field: FieldBody,
|
|
args: []string{"Hello $invalid_field"},
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "response body valid template",
|
|
field: FieldResponseBody,
|
|
args: []string{"Response: $req_method"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "status code valid",
|
|
field: FieldStatusCode,
|
|
args: []string{"200"},
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "status code invalid - too low",
|
|
field: FieldStatusCode,
|
|
args: []string{"99"},
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "status code invalid - too high",
|
|
field: FieldStatusCode,
|
|
args: []string{"600"},
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "status code invalid - not a number",
|
|
field: FieldStatusCode,
|
|
args: []string{"not-a-number"},
|
|
wantError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
field, exists := modFields[tt.field]
|
|
assert.True(t, exists, "Field %s does not exist", tt.field)
|
|
|
|
_, err := field.validate(tt.args)
|
|
if tt.wantError {
|
|
assert.Error(t, err, "Expected error but got none")
|
|
} else {
|
|
assert.NoError(t, err, "Expected no error but got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAllFields(t *testing.T) {
|
|
expectedFields := []string{
|
|
FieldHeader,
|
|
FieldResponseHeader,
|
|
FieldQuery,
|
|
FieldCookie,
|
|
FieldBody,
|
|
FieldResponseBody,
|
|
FieldStatusCode,
|
|
}
|
|
|
|
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
|
|
|
|
for _, expected := range expectedFields {
|
|
found := false
|
|
for _, actual := range AllFields {
|
|
if actual == expected {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, found, "Expected field %s not found in AllFields", expected)
|
|
}
|
|
}
|
|
|
|
func TestModFields(t *testing.T) {
|
|
for fieldName, field := range modFields {
|
|
// Test that each field has required components
|
|
assert.NotNil(t, field.validate, "Field %s has nil validate function", fieldName)
|
|
assert.NotNil(t, field.builder, "Field %s has nil builder function", fieldName)
|
|
assert.NotEmpty(t, field.help.command, "Field %s has empty help command", fieldName)
|
|
}
|
|
}
|