diff --git a/internal/route/route.go b/internal/route/route.go index 04a14fed..27897ed0 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -55,7 +55,7 @@ type ( route.HTTPConfig PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` - Rules rules.Rules `json:"rules,omitempty" extension:"x-nullable"` + Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"` RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` HealthCheck types.HealthCheckConfig `json:"healthcheck,omitempty" extensions:"x-nullable"` // null on load-balancer routes LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index f99a7091..4d310559 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -15,6 +15,7 @@ var ( ErrInvalidArguments = gperr.New("invalid arguments") ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") ErrInvalidCommandSequence = gperr.New("invalid command sequence") + ErrMultipleDefaultRules = gperr.New("multiple default rules") // vars errors ErrNoArgProvided = gperr.New("no argument provided") diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 6d017b0f..4471c656 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -26,6 +26,7 @@ func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool { } const ( + OnDefault = "default" OnHeader = "header" OnQuery = "query" OnCookie = "cookie" @@ -50,6 +51,22 @@ var checkers = map[string]struct { builder func(args any) CheckFunc isResponseChecker bool }{ + OnDefault: { + help: Help{ + command: OnDefault, + description: makeLines( + "The default rule is matched when no other rules are matched.", + ), + args: map[string]string{}, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 0 { + return nil, ErrExpectNoArg + } + return nil, nil + }, + builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called + }, OnHeader: { help: Help{ command: OnHeader, diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index ab014492..0fbf9b88 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -7,6 +7,7 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog/log" + gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "golang.org/x/net/http2" @@ -57,6 +58,19 @@ func (rule *Rule) IsResponseRule() bool { return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler() } +func (rules Rules) Validate() gperr.Error { + var defaultRulesFound []int + for i, rule := range rules { + if rule.Name == "default" || rule.On.raw == OnDefault { + defaultRulesFound = append(defaultRulesFound, i) + } + } + if len(defaultRulesFound) > 1 { + return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound)) + } + return nil +} + // BuildHandler returns a http.HandlerFunc that implements the rules. func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { if len(rules) == 0 { @@ -74,7 +88,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { var nonDefaultRules Rules hasDefaultRule := false for i, rule := range rules { - if rule.Name == "default" { + if rule.Name == "default" || rule.On.raw == OnDefault { defaultRule = rule hasDefaultRule = true } else { diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go new file mode 100644 index 00000000..7a4b2e81 --- /dev/null +++ b/internal/route/rules/rules_test.go @@ -0,0 +1,52 @@ +package rules + +import ( + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/serialization" +) + +func TestRulesValidate(t *testing.T) { + tests := []struct { + name string + rules string + want error + }{ + { + name: "no default rule", + rules: ` +- name: rule1 + on: header Host example.com + do: pass + `, + }, + { + name: "multiple default rules", + rules: ` +- name: default + do: pass +- name: rule1 + on: default + do: pass + `, + want: ErrMultipleDefaultRules, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var rules Rules + convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules)) + require.True(t, convertible) + + if tt.want == nil { + assert.NoError(t, err) + return + } + assert.ErrorIs(t, err, tt.want) + }) + } +}