mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-22 08:48:43 +02:00
v0.26.0
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/logging"
|
||||
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
@@ -71,10 +72,11 @@ var commands = map[string]struct {
|
||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -102,17 +104,17 @@ var commands = map[string]struct {
|
||||
"to": "the path to rewrite to, must start with /",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
path1, err1 := validateURLPath(args[:1])
|
||||
path2, err2 := validateURLPath(args[1:])
|
||||
if err1 != nil {
|
||||
err1 = gperr.Errorf("from: %w", err1)
|
||||
err1 = gperr.PrependSubject(err1, "from")
|
||||
}
|
||||
if err2 != nil {
|
||||
err2 = gperr.Errorf("to: %w", err2)
|
||||
err2 = gperr.PrependSubject(err2, "to")
|
||||
}
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, gperr.Join(err1, err2)
|
||||
@@ -188,7 +190,7 @@ var commands = map[string]struct {
|
||||
"route": "the route to route to",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -197,9 +199,10 @@ var commands = map[string]struct {
|
||||
build: func(args any) CommandHandler {
|
||||
route := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
|
||||
r, ok := routes.HTTP.Get(route)
|
||||
ep := entrypoint.FromCtx(req.Context())
|
||||
r, ok := ep.HTTPRoutes().Get(route)
|
||||
if !ok {
|
||||
excluded, has := routes.Excluded.Get(route)
|
||||
excluded, has := ep.ExcludedRoutes().Get(route)
|
||||
if has {
|
||||
r, ok = excluded.(types.HTTPRoute)
|
||||
}
|
||||
@@ -225,7 +228,7 @@ var commands = map[string]struct {
|
||||
"text": "the error message to return",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -265,7 +268,7 @@ var commands = map[string]struct {
|
||||
"realm": "the authentication realm",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) == 1 {
|
||||
return args[0], nil
|
||||
}
|
||||
@@ -327,12 +330,12 @@ var commands = map[string]struct {
|
||||
helpExample(CommandSet, "header", "User-Agent", "godoxy"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to set, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to set",
|
||||
"value": "the value to set",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldSet, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -347,12 +350,12 @@ var commands = map[string]struct {
|
||||
helpExample(CommandAdd, "header", "X-Foo", "bar"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to add, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to add",
|
||||
"value": "the value to add",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldAdd, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -367,11 +370,11 @@ var commands = map[string]struct {
|
||||
helpExample(CommandRemove, "header", "User-Agent"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to remove, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to remove",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldRemove, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -396,7 +399,7 @@ var commands = map[string]struct {
|
||||
"template": "the template to log",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrExpectThreeArgs
|
||||
}
|
||||
@@ -453,7 +456,7 @@ var commands = map[string]struct {
|
||||
"body": "the body of the notification",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 4 {
|
||||
return nil, ErrExpectFourArgs
|
||||
}
|
||||
@@ -509,8 +512,10 @@ var commands = map[string]struct {
|
||||
},
|
||||
}
|
||||
|
||||
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
|
||||
type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
|
||||
type (
|
||||
onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
|
||||
onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
|
||||
)
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
@@ -541,7 +546,7 @@ func (cmd *Command) Parse(v string) error {
|
||||
validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
||||
return err.Subject(directive).With(builder.help.Error())
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
handler := builder.build(validArgs)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
@@ -32,7 +31,7 @@ func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.
|
||||
}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
func parseRules(data string, target *Rules) error {
|
||||
_, err := serialization.ConvertString(data, reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
@@ -54,7 +53,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/users", nil)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -71,7 +70,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
upstream := mockUpstream(http.StatusOK, "success")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -86,7 +85,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -97,7 +96,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
upstream := mockUpstream(404, "not found")
|
||||
upstream := mockUpstream(http.StatusNotFound, "not found")
|
||||
|
||||
infoFile := TestRandomFileName()
|
||||
warnFile := TestRandomFileName()
|
||||
@@ -141,7 +140,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", "custom-value")
|
||||
w.Header().Set("Content-Length", "42")
|
||||
w.WriteHeader(201)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte("created"))
|
||||
})
|
||||
|
||||
@@ -177,13 +176,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/error":
|
||||
w.WriteHeader(500)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal server error"))
|
||||
case "/notfound":
|
||||
w.WriteHeader(404)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("not found"))
|
||||
default:
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
})
|
||||
@@ -207,22 +206,22 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test success request
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// Test not found request
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Test server error request
|
||||
req3 := httptest.NewRequest("POST", "/error", nil)
|
||||
req3 := httptest.NewRequest(http.MethodPost, "/error", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
assert.Equal(t, 500, w3.Code)
|
||||
assert.Equal(t, http.StatusInternalServerError, w3.Code)
|
||||
|
||||
// Verify success log
|
||||
successContent := TestFileContent(successFile)
|
||||
@@ -239,7 +238,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
upstream := mockUpstream(200, "response")
|
||||
upstream := mockUpstream(http.StatusOK, "response")
|
||||
|
||||
tempFile := TestRandomFileName()
|
||||
|
||||
@@ -267,7 +266,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Verify all requests were logged
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
@@ -228,7 +227,7 @@ var modFields = map[string]struct {
|
||||
"template": "the body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -273,7 +272,7 @@ var modFields = map[string]struct {
|
||||
"template": "the response body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -301,7 +300,7 @@ var modFields = map[string]struct {
|
||||
"code": "the status code",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -126,8 +126,8 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
values := w.Header()["X-Response-Test"]
|
||||
require.Len(t, values, 2)
|
||||
assert.Equal(t, values[0], "existing-value")
|
||||
assert.Equal(t, values[1], "additional-value")
|
||||
assert.Equal(t, "existing-value", values[0])
|
||||
assert.Equal(t, "additional-value", values[1])
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -143,7 +143,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
if tt.setup != nil {
|
||||
tt.setup(w)
|
||||
@@ -232,7 +232,7 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -330,7 +330,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -396,7 +396,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -494,7 +494,7 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
var cmd Command
|
||||
|
||||
@@ -3,7 +3,7 @@ package rules
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TestErrorFormat(t *testing.T) {
|
||||
@@ -19,5 +19,5 @@ func TestErrorFormat(t *testing.T) {
|
||||
do: set invalid_command
|
||||
- do: set resp_body "{{ .Request.Method {{ .Request.URL.Path }}"
|
||||
`, &rules)
|
||||
gperr.LogError("error", err)
|
||||
log.Err(err).Msg("error")
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ Generate help string as error, e.g.
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) Error() gperr.Error {
|
||||
func (h *Help) Error() error {
|
||||
var lines gperr.MultilineError
|
||||
|
||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
|
||||
@@ -17,16 +17,14 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
. "github.com/yusing/godoxy/internal/route/rules"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
func mockUpstream(body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
@@ -44,7 +42,7 @@ func mockRoute(alias string) *route.FileServer {
|
||||
return &route.FileServer{Route: &route.Route{Alias: alias}}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
func parseRules(data string, target *Rules) error {
|
||||
_, err := serialization.ConvertString(strings.TrimSpace(data), reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +50,7 @@ func parseRules(data string, target *Rules) gperr.Error {
|
||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream response"))
|
||||
})
|
||||
|
||||
@@ -66,18 +64,18 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
upstream := mockUpstream("upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -92,17 +90,17 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/bypass", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/bypass", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
upstream := mockUpstream("should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -117,18 +115,18 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/error", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/error", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
upstream := mockUpstream("should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -140,18 +138,18 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/old-path", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/old-path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("path: " + r.URL.Path))
|
||||
})
|
||||
|
||||
@@ -165,18 +163,18 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/users", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||
})
|
||||
|
||||
@@ -193,18 +191,18 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream: req-123", w.Body.String())
|
||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success", http.Header{
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
||||
"X-Upstream": []string{"upstream-value"},
|
||||
})
|
||||
|
||||
@@ -220,12 +218,12 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "success", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
|
||||
|
||||
@@ -238,10 +236,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/success" {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
} else {
|
||||
w.WriteHeader(404)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("not found"))
|
||||
}
|
||||
})
|
||||
@@ -261,18 +259,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request (should not log)
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// Test error request (should log)
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
@@ -284,7 +282,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||
})
|
||||
|
||||
@@ -305,19 +303,19 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test with Authorization header
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "hello authenticated-user", w1.Body.String())
|
||||
assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test without Authorization header
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "hello anonymous", w2.Body.String())
|
||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||
}
|
||||
@@ -327,13 +325,13 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
// Simulate different responses based on path
|
||||
if r.URL.Path == "/protected" {
|
||||
if r.Header.Get("X-Auth") != "valid" {
|
||||
w.WriteHeader(401)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("unauthorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("X-Response-Time", "100ms")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
})
|
||||
|
||||
@@ -361,32 +359,32 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request
|
||||
req1 := httptest.NewRequest("GET", "/public", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/public", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "success", w1.Body.String())
|
||||
assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id"))
|
||||
assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time"))
|
||||
|
||||
// Test unauthorized protected request
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 401, w2.Code)
|
||||
assert.Equal(t, w2.Body.String(), "Unauthorized\n")
|
||||
assert.Equal(t, http.StatusUnauthorized, w2.Code)
|
||||
assert.Equal(t, "Unauthorized\n", w2.Body.String())
|
||||
|
||||
// Test authorized protected request
|
||||
req3 := httptest.NewRequest("GET", "/protected", nil)
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req3.SetBasicAuth("user", "pass")
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
// This should fail because our simple upstream expects X-Auth: valid header
|
||||
// but the basic auth requirement should add the appropriate header
|
||||
assert.Equal(t, 401, w3.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, w3.Code)
|
||||
|
||||
// Check log files
|
||||
logContent := TestFileContent(logFile)
|
||||
@@ -405,7 +403,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
upstream := mockUpstream("upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -420,20 +418,20 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test default rule
|
||||
req1 := httptest.NewRequest("GET", "/regular", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule + default rule
|
||||
req2 := httptest.NewRequest("GET", "/special", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
@@ -443,7 +441,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
// Echo back a header
|
||||
headerValue := r.Header.Get("X-Test-Header")
|
||||
w.Header().Set("X-Echoed-Header", headerValue)
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("header echoed"))
|
||||
})
|
||||
|
||||
@@ -461,14 +459,14 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Secret", "secret-value")
|
||||
req.Header.Set("X-Test-Header", "original-value")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||
// Ensure the secret header was removed and not passed to upstream
|
||||
@@ -478,7 +476,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("query: " + query.Get("param")))
|
||||
})
|
||||
|
||||
@@ -492,25 +490,23 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/path?param=original", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/path?param=original", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
// The set command should have modified the query parameter
|
||||
assert.Equal(t, "query: added-value", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test files directly in the temp directory
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
@@ -521,7 +517,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
`, tempDir), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
|
||||
// Test serving a file - serve command serves files relative to the root directory
|
||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||
@@ -534,7 +530,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/files/index.html", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/files/index.html", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
@@ -543,18 +539,18 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
assert.NotEqual(t, "should not be called", w1.Body.String())
|
||||
|
||||
// Test file not found
|
||||
req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/files/nonexistent.html", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
// Create a mock upstream server
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream response"))
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
@@ -567,15 +563,15 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
`, upstreamServer.URL), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// The proxy command should forward the request to the upstream server
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||
}
|
||||
@@ -586,7 +582,7 @@ func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("form processed"))
|
||||
})
|
||||
|
||||
@@ -605,28 +601,28 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
|
||||
// Test form condition
|
||||
formData := url.Values{"username": {"john_doe"}}
|
||||
req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode()))
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(formData.Encode()))
|
||||
req1.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "john_doe", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test postform condition
|
||||
postFormData := url.Values{"email": {"john@example.com"}}
|
||||
req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode()))
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(postFormData.Encode()))
|
||||
req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("remote processed"))
|
||||
})
|
||||
|
||||
@@ -644,27 +640,27 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test localhost condition
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.RemoteAddr = "127.0.0.1:12345"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "local", w1.Header().Get("X-Access"))
|
||||
|
||||
// Test private network block
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.RemoteAddr = "192.168.1.100:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 403, w2.Code)
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code)
|
||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("auth processed"))
|
||||
})
|
||||
|
||||
@@ -688,27 +684,27 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin user
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.SetBasicAuth("admin", "adminpass")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status"))
|
||||
|
||||
// Test guest user
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.SetBasicAuth("guest", "guestpass")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("route processed"))
|
||||
})
|
||||
|
||||
@@ -726,29 +722,29 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test API route
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1 = routes.WithRouteContext(req1, mockRoute("backend"))
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "backend", w1.Header().Get("X-Route"))
|
||||
|
||||
// Test admin route
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2 = routes.WithRouteContext(req2, mockRoute("frontend"))
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(405)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte("method not allowed"))
|
||||
})
|
||||
|
||||
@@ -763,18 +759,18 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Response-Header", "response header")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("processed"))
|
||||
})
|
||||
|
||||
@@ -789,11 +785,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
t.Run("with_value", func(t *testing.T) {
|
||||
@@ -807,11 +803,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
|
||||
@@ -826,18 +822,18 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "processed", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("complex processed"))
|
||||
})
|
||||
|
||||
@@ -868,26 +864,26 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin API (should match first rule)
|
||||
req1 := httptest.NewRequest("POST", "/api/admin/users", nil)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/admin/users", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w1.Header()["X-API-Version"][0])
|
||||
|
||||
// Test user API (should match second rule)
|
||||
req2 := httptest.NewRequest("GET", "/api/users/profile", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/users/profile", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "user", w2.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w2.Header()["X-API-Version"][0])
|
||||
|
||||
// Test public API (should match third rule)
|
||||
req3 := httptest.NewRequest("GET", "/api/public/info", nil)
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/api/public/info", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
@@ -898,7 +894,7 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("original response"))
|
||||
})
|
||||
|
||||
@@ -913,12 +909,12 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "true", w.Header().Get("X-Modified"))
|
||||
assert.Equal(t, "Modified: GET /test\n", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type noopWriteCloser struct {
|
||||
@@ -31,7 +30,7 @@ var (
|
||||
testFilesLock sync.Mutex
|
||||
)
|
||||
|
||||
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||
func openFile(path string) (io.WriteCloser, error) {
|
||||
switch path {
|
||||
case "/dev/stdout":
|
||||
return stdout, nil
|
||||
|
||||
@@ -41,6 +41,7 @@ const (
|
||||
OnRoute = "route"
|
||||
|
||||
// on response
|
||||
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
@@ -59,10 +60,11 @@ var checkers = map[string]struct {
|
||||
),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
||||
@@ -251,7 +253,7 @@ var checkers = map[string]struct {
|
||||
"proto": "the http protocol (http, https, h3)",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -581,7 +583,7 @@ func (on *RuleOn) Parse(v string) error {
|
||||
}
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.Add(err.Subjectf("line %d", i+1))
|
||||
errs.AddSubjectf(err, "line %d", i+1)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
@@ -603,7 +605,7 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
func parseOn(line string) (Checker, bool, error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
@@ -645,7 +647,7 @@ func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, err.With(checker.help.Error())
|
||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
||||
}
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
|
||||
@@ -31,7 +31,7 @@ var quoteChars = [256]bool{
|
||||
// error 403 "Forbidden 'foo' 'bar'"
|
||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||
func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
func parse(v string) (subject string, args []string, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
||||
|
||||
escaped := false
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
@@ -15,7 +13,6 @@ func TestParser(t *testing.T) {
|
||||
input string
|
||||
subject string
|
||||
args []string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
@@ -93,10 +90,6 @@ func TestParser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, args, err := parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
expect.ErrorIs(t, tt.wantErr, err)
|
||||
return
|
||||
}
|
||||
// t.Log(subject, args, err)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
@@ -105,12 +98,8 @@ func TestParser(t *testing.T) {
|
||||
}
|
||||
t.Run("env substitution", func(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
|
||||
os.Setenv("DOMAIN", "example.com")
|
||||
defer func() {
|
||||
os.Unsetenv("CLOUDFLARE_API_KEY")
|
||||
os.Unsetenv("DOMAIN")
|
||||
}()
|
||||
t.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
|
||||
t.Setenv("DOMAIN", "example.com")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
//go:embed *.yml
|
||||
@@ -35,12 +34,12 @@ func initPresets() {
|
||||
var rules rules.Rules
|
||||
content, err := fs.ReadFile(file.Name())
|
||||
if err != nil {
|
||||
gperr.LogError("failed to read rule preset", err)
|
||||
log.Err(err).Msg("failed to read rule preset")
|
||||
continue
|
||||
}
|
||||
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&rules))
|
||||
if err != nil {
|
||||
gperr.LogError("failed to unmarshal rule preset", err)
|
||||
log.Err(err).Msg("failed to unmarshal rule preset")
|
||||
continue
|
||||
}
|
||||
rulePresets[file.Name()] = rules
|
||||
|
||||
@@ -3,12 +3,19 @@
|
||||
do: pass
|
||||
- name: protected
|
||||
on: |
|
||||
!path regex("(_next/static|_next/image|favicon.ico).*")
|
||||
!path glob("/api/v1/auth/*")
|
||||
!path glob("/auth/*")
|
||||
!path regex("[A-Za-z0-9_-]+\.(svg|png|jpg|jpeg|gif|ico|webp|woff2?|eot|ttf|otf|txt)(\?.+)?")
|
||||
!path /icon0.svg
|
||||
!path /favicon.ico
|
||||
!path /apple-icon.png
|
||||
!path glob("/web-app-manifest-*x*.png")
|
||||
!path regex("\/assets\/(chunks\/)?[a-zA-Z0-9\-_]+\.(css|js|woff2)")
|
||||
!path regex("\/assets\/workbox-window\.prod\.es5-[a-zA-Z0-9]+\.js")
|
||||
!path regex("/workbox-[a-zA-Z0-9]+\.js")
|
||||
!path /api/v1/version
|
||||
!path /manifest.json
|
||||
!path /sw.js
|
||||
!path /registerSW.js
|
||||
do: require_auth
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
|
||||
26
internal/route/rules/presets/webui_dev.yml
Normal file
26
internal/route/rules/presets/webui_dev.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
- name: login page
|
||||
on: path /login
|
||||
do: pass
|
||||
- name: protected
|
||||
on: |
|
||||
!path glob("/@tanstack-start/*")
|
||||
!path glob("/@vite-plugin-pwa/*")
|
||||
!path glob("/__tsd/*")
|
||||
!path /@react-refresh
|
||||
!path /@vite/client
|
||||
!path regex("/\?token=[a-zA-Z0-9-_]+")
|
||||
!path glob("/@id/*")
|
||||
!path glob("/api/v1/auth/*")
|
||||
!path glob("/auth/*")
|
||||
!path regex("([A-Za-z0-9_\-/]+)+\.(css|ts|js|mjs|svg|png|jpg|jpeg|gif|ico|webp|woff2?|eot|ttf|otf|txt)(\?.*)?")
|
||||
!path /api/v1/version
|
||||
!path /manifest.json
|
||||
do: require_auth
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
do: proxy http://${API_ADDR}/
|
||||
- name: proxy to auth api
|
||||
on: path glob("/auth/*")
|
||||
do: |
|
||||
rewrite /auth /api/v1/auth
|
||||
proxy http://${API_ADDR}/
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog/log"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
@@ -58,7 +57,7 @@ func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
}
|
||||
|
||||
func (rules Rules) Validate() gperr.Error {
|
||||
func (rules Rules) Validate() error {
|
||||
var defaultRulesFound []int
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
ValidateFunc func(args []string) (any, gperr.Error)
|
||||
ValidateFunc func(args []string) (any, error)
|
||||
Tuple[T1, T2 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
@@ -62,7 +62,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
func validateSingleMatcher(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
return &MapValueMatcher{args[0], nil}, nil
|
||||
@@ -85,7 +85,7 @@ func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
func toKeyValueTemplate(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
func validateURL(args []string) (any, gperr.Error) {
|
||||
func validateURL(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func validateAbsoluteURL(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||
func validateCIDR(args []string) (any, gperr.Error) {
|
||||
func validateCIDR(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -149,7 +149,7 @@ func validateCIDR(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateURLPath returns string with the path validated.
|
||||
func validateURLPath(args []string) (any, gperr.Error) {
|
||||
func validateURLPath(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -166,7 +166,7 @@ func validateURLPath(args []string) (any, gperr.Error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
func validateURLPathMatcher(args []string) (any, error) {
|
||||
path, err := validateURLPath(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -175,7 +175,7 @@ func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateFSPath returns string with the path validated.
|
||||
func validateFSPath(args []string) (any, gperr.Error) {
|
||||
func validateFSPath(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func validateFSPath(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateMethod returns string with the method validated.
|
||||
func validateMethod(args []string) (any, gperr.Error) {
|
||||
func validateMethod(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func validateStatusCode(status string) (int, error) {
|
||||
// - 3xx
|
||||
// - 4xx
|
||||
// - 5xx
|
||||
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
func validateStatusRange(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
func validateUserBCryptPassword(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -258,7 +258,7 @@ func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateModField returns CommandHandler with the field validated.
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.Error) {
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, ErrExpectTwoOrThreeArgs
|
||||
}
|
||||
@@ -275,7 +275,7 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
|
||||
}
|
||||
validArgs, err := setField.validate(args[1:])
|
||||
if err != nil {
|
||||
return nil, err.With(setField.help.Error())
|
||||
return nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
}
|
||||
modder := setField.builder(validArgs)
|
||||
switch mod {
|
||||
@@ -299,7 +299,7 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error) {
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
@@ -310,22 +310,15 @@ func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error
|
||||
|
||||
err := ValidateVars(tmplStr)
|
||||
if err != nil {
|
||||
return templateString{}, gperr.Wrap(err)
|
||||
return templateString{}, err
|
||||
}
|
||||
return templateString{tmplStr, true}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
func validateLevel(level string) (zerolog.Level, error) {
|
||||
l, err := zerolog.ParseLevel(level)
|
||||
if err != nil {
|
||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// func validateNotifProvider(provider string) gperr.Error {
|
||||
// if !notif.HasProvider(provider) {
|
||||
// return ErrInvalidArguments.Subject(provider)
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
@@ -2,6 +2,7 @@ package rules
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
@@ -11,9 +12,9 @@ import (
|
||||
|
||||
func BenchmarkExpandVars(b *testing.B) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
testResponseModifier.Write([]byte("Hello, world!"))
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Set("X-Custom", "value1,value2")
|
||||
testRequest.ContentLength = 12345
|
||||
|
||||
@@ -203,7 +203,7 @@ func TestExpandVars(t *testing.T) {
|
||||
postFormData.Add("postmulti", "first")
|
||||
postFormData.Add("postmulti", "second")
|
||||
|
||||
testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode()))
|
||||
testRequest := httptest.NewRequest(http.MethodPost, "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode()))
|
||||
testRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Add("X-Custom", "value1")
|
||||
@@ -218,7 +218,7 @@ func TestExpandVars(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Content-Type", "text/html")
|
||||
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
// set content length to 9876 by writing 9876 'a' bytes
|
||||
testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876))
|
||||
|
||||
@@ -498,12 +498,12 @@ func TestExpandVars(t *testing.T) {
|
||||
|
||||
func TestExpandVars_Integration(t *testing.T) {
|
||||
t.Run("complex log format", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "https://api.example.com/users/123?sort=asc", nil)
|
||||
testRequest.Header.Set("User-Agent", "curl/7.68.0")
|
||||
testRequest.RemoteAddr = "10.0.0.1:54321"
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
@@ -515,7 +515,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("with query parameters", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "http://example.com/search?q=test&page=1", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
@@ -529,12 +529,12 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("response headers", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Cache-Control", "no-cache")
|
||||
testResponseModifier.Header().Set("X-Rate-Limit", "100")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
@@ -554,7 +554,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "http scheme",
|
||||
request: httptest.NewRequest("GET", "http://example.com/", nil),
|
||||
request: httptest.NewRequest(http.MethodGet, "http://example.com/", nil),
|
||||
expected: "http",
|
||||
},
|
||||
{
|
||||
@@ -581,7 +581,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
|
||||
func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
// Upstream variables require context from routes package
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
@@ -607,7 +607,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
|
||||
func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
// Test request without port in Host header
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.Host = "example.com" // No port
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
@@ -623,13 +623,13 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
// Test request without port in RemoteAddr
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.RemoteAddr = "192.168.1.1" // No port
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
@@ -638,19 +638,19 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
|
||||
t.Run("remote_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/test", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
|
||||
Reference in New Issue
Block a user