fix: middleware bypass

This commit is contained in:
yusing
2025-05-11 06:33:22 +08:00
parent f1eefde964
commit 71ca8c738e
7 changed files with 294 additions and 273 deletions

View File

@@ -8,8 +8,8 @@ import (
"github.com/gobwas/glob"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -242,7 +242,7 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
route := args.(string)
return func(_ Cache, r *http.Request) bool {
return reverseproxy.TryGetUpstreamName(r) == route
return routes.TryGetUpstreamName(r) == route
}
},
},

View File

@@ -0,0 +1,195 @@
package rules
import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSplitAnd(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "empty",
input: "",
want: []string{},
},
{
name: "single",
input: "rule",
want: []string{"rule"},
},
{
name: "multiple",
input: "rule1 & rule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline",
input: "rule1\n\nrule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline_and",
input: "rule1\nrule2 & rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "empty segment",
input: "rule1\n& &rule2& rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "double_and",
input: "rule1\nrule2 && rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "spaces_around",
input: " rule1\nrule2 & rule3 ",
want: []string{"rule1", "rule2", "rule3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitAnd(tt.input)
expect.Equal(t, got, tt.want)
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr gperr.Error
}{
// header
{
name: "header_valid_kv",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_valid_k",
input: "header Connection",
wantErr: nil,
},
{
name: "header_missing_arg",
input: "header",
wantErr: ErrExpectKVOptionalV,
},
// query
{
name: "query_valid_kv",
input: "query key value",
wantErr: nil,
},
{
name: "query_valid_k",
input: "query key",
wantErr: nil,
},
{
name: "query_missing_arg",
input: "query",
wantErr: ErrExpectKVOptionalV,
},
{
name: "cookie_valid_kv",
input: "cookie key value",
wantErr: nil,
},
{
name: "cookie_valid_k",
input: "cookie key",
wantErr: nil,
},
{
name: "cookie_missing_arg",
input: "cookie",
wantErr: ErrExpectKVOptionalV,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method invalid",
wantErr: ErrInvalidArguments,
},
{
name: "method_missing_arg",
input: "method",
wantErr: ErrExpectOneArg,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_missing_arg",
input: "path",
wantErr: ErrExpectOneArg,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote abcd",
wantErr: ErrInvalidArguments,
},
{
name: "remote_missing_arg",
input: "remote",
wantErr: ErrExpectOneArg,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
// route
{
name: "route_valid",
input: "route example",
wantErr: nil,
},
{
name: "route_missing_arg",
input: "route",
wantErr: ErrExpectOneArg,
},
{
name: "route_extra_arg",
input: "route example1 example2",
wantErr: ErrExpectOneArg,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
expect.HasError(t, tt.wantErr, err)
} else {
expect.NoError(t, err)
}
})
}
}

View File

@@ -1,4 +1,4 @@
package rules
package rules_test
import (
"encoding/base64"
@@ -7,199 +7,13 @@ import (
"net/url"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/routes"
. "github.com/yusing/go-proxy/internal/route/rules"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
func TestSplitAnd(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "empty",
input: "",
want: []string{},
},
{
name: "single",
input: "rule",
want: []string{"rule"},
},
{
name: "multiple",
input: "rule1 & rule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline",
input: "rule1\n\nrule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline_and",
input: "rule1\nrule2 & rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "empty segment",
input: "rule1\n& &rule2& rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "double_and",
input: "rule1\nrule2 && rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "spaces_around",
input: " rule1\nrule2 & rule3 ",
want: []string{"rule1", "rule2", "rule3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitAnd(tt.input)
ExpectEqual(t, got, tt.want)
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr gperr.Error
}{
// header
{
name: "header_valid_kv",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_valid_k",
input: "header Connection",
wantErr: nil,
},
{
name: "header_missing_arg",
input: "header",
wantErr: ErrExpectKVOptionalV,
},
// query
{
name: "query_valid_kv",
input: "query key value",
wantErr: nil,
},
{
name: "query_valid_k",
input: "query key",
wantErr: nil,
},
{
name: "query_missing_arg",
input: "query",
wantErr: ErrExpectKVOptionalV,
},
{
name: "cookie_valid_kv",
input: "cookie key value",
wantErr: nil,
},
{
name: "cookie_valid_k",
input: "cookie key",
wantErr: nil,
},
{
name: "cookie_missing_arg",
input: "cookie",
wantErr: ErrExpectKVOptionalV,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method invalid",
wantErr: ErrInvalidArguments,
},
{
name: "method_missing_arg",
input: "method",
wantErr: ErrExpectOneArg,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_missing_arg",
input: "path",
wantErr: ErrExpectOneArg,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote abcd",
wantErr: ErrInvalidArguments,
},
{
name: "remote_missing_arg",
input: "remote",
wantErr: ErrExpectOneArg,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
// route
{
name: "route_valid",
input: "route example",
wantErr: nil,
},
{
name: "route_missing_arg",
input: "route",
wantErr: ErrExpectOneArg,
},
{
name: "route_extra_arg",
input: "route example1 example2",
wantErr: ErrExpectOneArg,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
type testCorrectness struct {
name string
checker string
@@ -284,7 +98,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_correct",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
@@ -294,7 +108,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_incorrect",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
@@ -305,7 +119,10 @@ func TestOnCorrectness(t *testing.T) {
{
name: "route_match",
checker: "route example",
input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}),
input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
Alias: "example",
Root: "/",
}))),
want: true,
},
{
@@ -354,12 +171,11 @@ func TestOnCorrectness(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on, err := parseOn(tt.checker)
ExpectNoError(t, err)
var on RuleOn
err := on.Parse(tt.checker)
expect.NoError(t, err)
got := on.Check(Cache{}, tt.input)
if tt.want != got {
t.Errorf("want %v, got %v", tt.want, got)
}
expect.Equal(t, tt.want, got)
})
}
}