fix(lint): improve styling and fix lint errors

This commit is contained in:
yusing
2026-02-10 16:57:41 +08:00
parent 978dd886c0
commit a0d0ad0958
42 changed files with 317 additions and 341 deletions

View File

@@ -23,19 +23,21 @@ import (
apitypes "github.com/yusing/goutils/apitypes" apitypes "github.com/yusing/goutils/apitypes"
) )
// NewHandler creates a new Gin engine for the API.
//
// @title GoDoxy API // @title GoDoxy API
// @version 1.0 // @version 1.0
// @description GoDoxy API // @description GoDoxy API
// @termsOfService https://github.com/yusing/godoxy/blob/main/LICENSE // @termsOfService https://github.com/yusing/godoxy/blob/main/LICENSE
//
// @contact.name Yusing // @contact.name Yusing
// @contact.url https://github.com/yusing/godoxy/issues // @contact.url https://github.com/yusing/godoxy/issues
//
// @license.name MIT // @license.name MIT
// @license.url https://github.com/yusing/godoxy/blob/main/LICENSE // @license.url https://github.com/yusing/godoxy/blob/main/LICENSE
//
// @BasePath /api/v1 // @BasePath /api/v1
//
// @externalDocs.description GoDoxy Docs // @externalDocs.description GoDoxy Docs
// @externalDocs.url https://docs.godoxy.dev // @externalDocs.url https://docs.godoxy.dev
func NewHandler(requireAuth bool) *gin.Engine { func NewHandler(requireAuth bool) *gin.Engine {

View File

@@ -7,7 +7,6 @@ import (
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/homepage" "github.com/yusing/godoxy/internal/homepage"
_ "github.com/yusing/goutils/apitypes"
apitypes "github.com/yusing/goutils/apitypes" apitypes "github.com/yusing/goutils/apitypes"
) )

View File

@@ -257,8 +257,8 @@ func handlerWithRecover(w http.ResponseWriter, r *http.Request, h http.HandlerFu
} }
func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) { func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
var parsedRules []ParsedRule parsedRules := make([]ParsedRule, 0, len(rawRules))
var rulesList rules.Rules rulesList := make(rules.Rules, 0, len(rawRules))
var valErrs gperr.Builder var valErrs gperr.Builder

View File

@@ -79,7 +79,7 @@ func TestPlayground(t *testing.T) {
if len(resp.MatchedRules) != 1 { if len(resp.MatchedRules) != 1 {
t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules)) t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules))
} }
if resp.FinalResponse.StatusCode != 403 { if resp.FinalResponse.StatusCode != http.StatusForbidden {
t.Errorf("expected status 403, got %d", resp.FinalResponse.StatusCode) t.Errorf("expected status 403, got %d", resp.FinalResponse.StatusCode)
} }
if resp.UpstreamCalled { if resp.UpstreamCalled {
@@ -168,7 +168,7 @@ func TestPlayground(t *testing.T) {
if len(resp.MatchedRules) != 1 { if len(resp.MatchedRules) != 1 {
t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules)) t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules))
} }
if resp.FinalResponse.StatusCode != 405 { if resp.FinalResponse.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("expected status 405, got %d", resp.FinalResponse.StatusCode) t.Errorf("expected status 405, got %d", resp.FinalResponse.StatusCode)
} }
}, },
@@ -179,7 +179,7 @@ func TestPlayground(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create request // Create request
body, _ := json.Marshal(tt.request) body, _ := json.Marshal(tt.request)
req := httptest.NewRequest("POST", "/api/v1/route/playground", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/v1/route/playground", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
// Create response recorder // Create response recorder
@@ -214,7 +214,7 @@ func TestPlayground(t *testing.T) {
func TestPlaygroundInvalidRequest(t *testing.T) { func TestPlaygroundInvalidRequest(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
req := httptest.NewRequest("POST", "/api/v1/route/playground", bytes.NewReader([]byte(`{}`))) req := httptest.NewRequest(http.MethodPost, "/api/v1/route/playground", bytes.NewReader([]byte(`{}`)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -135,7 +135,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) { func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) {
claims = &sessionClaims{} claims = &sessionClaims{}
sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (interface{}, error) { sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
} }

View File

@@ -24,8 +24,9 @@ type (
tokenTTL time.Duration tokenTTL time.Duration
} }
UserPassClaims struct { UserPassClaims struct {
Username string `json:"username"`
jwt.RegisteredClaims jwt.RegisteredClaims
Username string `json:"username"`
} }
) )
@@ -78,7 +79,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
return ErrMissingSessionToken return ErrMissingSessionToken
} }
var claims UserPassClaims var claims UserPassClaims
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
} }

View File

@@ -467,7 +467,7 @@ func (p *Provider) scheduleRenewal(parent task.Parent) {
log.Warn().Err(p.fmtError(err)).Msg("autocert: cert renew failed") log.Warn().Err(p.fmtError(err)).Msg("autocert: cert renew failed")
notif.Notify(&notif.LogMessage{ notif.Notify(&notif.LogMessage{
Level: zerolog.ErrorLevel, Level: zerolog.ErrorLevel,
Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()), Title: "SSL certificate renewal failed for " + p.GetName(),
Body: notif.MessageBody(err.Error()), Body: notif.MessageBody(err.Error()),
}) })
return return
@@ -477,7 +477,7 @@ func (p *Provider) scheduleRenewal(parent task.Parent) {
notif.Notify(&notif.LogMessage{ notif.Notify(&notif.LogMessage{
Level: zerolog.InfoLevel, Level: zerolog.InfoLevel,
Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()), Title: "SSL certificate renewed for " + p.GetName(),
Body: notif.ListBody(p.cfg.Domains), Body: notif.ListBody(p.cfg.Domains),
}) })

View File

@@ -4,7 +4,7 @@ import "context"
type ContextKey struct{} type ContextKey struct{}
func SetCtx(ctx interface{ SetValue(any, any) }, p Provider) { func SetCtx(ctx interface{ SetValue(key, value any) }, p Provider) {
ctx.SetValue(ContextKey{}, p) ctx.SetValue(ContextKey{}, p)
} }

View File

@@ -26,11 +26,9 @@ var (
const configEventFlushInterval = 500 * time.Millisecond const configEventFlushInterval = 500 * time.Millisecond
const ( var (
cfgRenameWarn = `Config file renamed, not reloading. errCfgRenameWarn = errors.New("config file renamed, not reloading; Make sure you rename it back before next time you start")
Make sure you rename it back before next time you start.` errCfgDeleteWarn = errors.New(`config file deleted, not reloading; You may run "ls-config" to show or dump the current config`)
cfgDeleteWarn = `Config file deleted, not reloading.
You may run "ls-config" to show or dump the current config.`
) )
func logNotifyError(action string, err error) { func logNotifyError(action string, err error) {
@@ -142,11 +140,11 @@ func OnConfigChange(ev []watcherEvents.Event) {
// no matter how many events during the interval // no matter how many events during the interval
// just reload once and check the last event // just reload once and check the last event
switch ev[len(ev)-1].Action { switch ev[len(ev)-1].Action {
case events.ActionFileRenamed: case watcherEvents.ActionFileRenamed:
logNotifyWarn("rename", errors.New(cfgRenameWarn)) logNotifyWarn("rename", errCfgRenameWarn)
return return
case events.ActionFileDeleted: case watcherEvents.ActionFileDeleted:
logNotifyWarn("delete", errors.New(cfgDeleteWarn)) logNotifyWarn("delete", errCfgDeleteWarn)
return return
} }

View File

@@ -17,6 +17,7 @@ import (
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/puzpuzpuz/xsync/v4" "github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
acl "github.com/yusing/godoxy/internal/acl/types" acl "github.com/yusing/godoxy/internal/acl/types"
"github.com/yusing/godoxy/internal/agentpool" "github.com/yusing/godoxy/internal/agentpool"
"github.com/yusing/godoxy/internal/api" "github.com/yusing/godoxy/internal/api"
@@ -90,11 +91,6 @@ func SetState(state config.State) {
cfg := state.Value() cfg := state.Value()
config.ActiveState.Store(state) config.ActiveState.Store(state)
homepage.ActiveConfig.Store(&cfg.Homepage) homepage.ActiveConfig.Store(&cfg.Homepage)
if autocertProvider := state.AutoCertProvider(); autocertProvider != nil {
autocert.ActiveProvider.Store(autocertProvider.(*autocert.Provider))
} else {
autocert.ActiveProvider.Store(nil)
}
} }
func HasState() bool { func HasState() bool {
@@ -203,25 +199,31 @@ func (state *state) NumProviders() int {
} }
func (state *state) FlushTmpLog() { func (state *state) FlushTmpLog() {
state.tmpLogBuf.WriteTo(os.Stdout) _, _ = state.tmpLogBuf.WriteTo(os.Stdout)
state.tmpLogBuf.Reset() state.tmpLogBuf.Reset()
} }
func (state *state) StartAPIServers() { func (state *state) StartAPIServers() {
// API Handler needs to start after auth is initialized. // API Handler needs to start after auth is initialized.
server.StartServer(state.task.Subtask("api_server", false), server.Options{ _, err := server.StartServer(state.task.Subtask("api_server", false), server.Options{
Name: "api", Name: "api",
HTTPAddr: common.APIHTTPAddr, HTTPAddr: common.APIHTTPAddr,
Handler: api.NewHandler(true), Handler: api.NewHandler(true),
}) })
if err != nil {
log.Err(err).Msg("failed to start API server")
}
// Local API Handler is used for unauthenticated access. // Local API Handler is used for unauthenticated access.
if common.LocalAPIHTTPAddr != "" { if common.LocalAPIHTTPAddr != "" {
server.StartServer(state.task.Subtask("local_api_server", false), server.Options{ _, err := server.StartServer(state.task.Subtask("local_api_server", false), server.Options{
Name: "local_api", Name: "local_api",
HTTPAddr: common.LocalAPIHTTPAddr, HTTPAddr: common.LocalAPIHTTPAddr,
Handler: api.NewHandler(false), Handler: api.NewHandler(false),
}) })
if err != nil {
log.Err(err).Msg("failed to start local API server")
}
} }
} }

View File

@@ -47,10 +47,10 @@ var _ entrypoint.Entrypoint = &Entrypoint{}
var emptyCfg Config var emptyCfg Config
func NewTestEntrypoint(t testing.TB, cfg *Config) *Entrypoint { func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
t.Helper() tb.Helper()
testTask := task.GetTestTask(t) testTask := task.GetTestTask(tb)
ep := NewEntrypoint(testTask, cfg) ep := NewEntrypoint(testTask, cfg)
entrypoint.SetCtx(testTask, ep) entrypoint.SetCtx(testTask, ep)
return ep return ep
@@ -160,6 +160,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request
} }
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute { func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
//nolint:modernize
idx := strings.IndexByte(host, '.') idx := strings.IndexByte(host, '.')
if idx != -1 { if idx != -1 {
target := host[:idx] target := host[:idx]

View File

@@ -19,7 +19,7 @@ import (
"github.com/yusing/goutils/server" "github.com/yusing/goutils/server"
) )
// httpServer is a server that listens on a given address and serves HTTP routes. // HTTPServer is a server that listens on a given address and serves HTTP routes.
type HTTPServer interface { type HTTPServer interface {
Listen(addr string, proto HTTPProto) error Listen(addr string, proto HTTPProto) error
AddRoute(route types.HTTPRoute) AddRoute(route types.HTTPRoute)
@@ -109,6 +109,8 @@ func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rec := accesslog.GetResponseRecorder(w) rec := accesslog.GetResponseRecorder(w)
w = rec w = rec
defer func() { defer func() {
// there is no body to close
//nolint:bodyclose
srv.ep.accessLogger.LogRequest(r, rec.Response()) srv.ep.accessLogger.LogRequest(r, rec.Response())
accesslog.PutResponseRecorder(rec) accesslog.PutResponseRecorder(rec)
}() }()

View File

@@ -19,7 +19,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) {
matcher.AddRoute("app.domain.com") matcher.AddRoute("app.domain.com")
t.Run("exact path", func(t *testing.T) { t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -28,7 +28,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) {
}) })
t.Run("with path remainder", func(t *testing.T) { t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil) req := httptest.NewRequest(http.MethodGet, "/app/foo/bar", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -37,7 +37,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) {
}) })
t.Run("with query", func(t *testing.T) { t.Run("with query", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo?x=y&z=1", nil) req := httptest.NewRequest(http.MethodGet, "/app/foo?x=y&z=1", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -53,7 +53,7 @@ func TestShortLinkMatcher_SubdomainAlias(t *testing.T) {
matcher.AddRoute("app") matcher.AddRoute("app")
t.Run("exact path", func(t *testing.T) { t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -62,7 +62,7 @@ func TestShortLinkMatcher_SubdomainAlias(t *testing.T) {
}) })
t.Run("with path remainder", func(t *testing.T) { t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil) req := httptest.NewRequest(http.MethodGet, "/app/foo/bar", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -78,7 +78,7 @@ func TestShortLinkMatcher_NotFound(t *testing.T) {
matcher.AddRoute("app") matcher.AddRoute("app")
t.Run("missing key", func(t *testing.T) { t.Run("missing key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -86,7 +86,7 @@ func TestShortLinkMatcher_NotFound(t *testing.T) {
}) })
t.Run("unknown key", func(t *testing.T) { t.Run("unknown key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/unknown", nil) req := httptest.NewRequest(http.MethodGet, "/unknown", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -103,13 +103,13 @@ func TestShortLinkMatcher_AddDelRoute(t *testing.T) {
matcher.AddRoute("app2.domain.com") matcher.AddRoute("app2.domain.com")
t.Run("both routes work", func(t *testing.T) { t.Run("both routes work", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app1", nil) req := httptest.NewRequest(http.MethodGet, "/app1", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location")) assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location"))
req = httptest.NewRequest("GET", "/app2.domain.com", nil) req = httptest.NewRequest(http.MethodGet, "/app2.domain.com", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
@@ -119,12 +119,12 @@ func TestShortLinkMatcher_AddDelRoute(t *testing.T) {
t.Run("delete route", func(t *testing.T) { t.Run("delete route", func(t *testing.T) {
matcher.DelRoute("app1") matcher.DelRoute("app1")
req := httptest.NewRequest("GET", "/app1", nil) req := httptest.NewRequest(http.MethodGet, "/app1", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code) assert.Equal(t, http.StatusNotFound, w.Code)
req = httptest.NewRequest("GET", "/app2.domain.com", nil) req = httptest.NewRequest(http.MethodGet, "/app2.domain.com", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
@@ -140,7 +140,7 @@ func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) {
t.Run("subdomain alias ignored", func(t *testing.T) { t.Run("subdomain alias ignored", func(t *testing.T) {
matcher.AddRoute("app") matcher.AddRoute("app")
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -150,7 +150,7 @@ func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) {
t.Run("FQDN alias still works", func(t *testing.T) { t.Run("FQDN alias still works", func(t *testing.T) {
matcher.AddRoute("app.domain.com") matcher.AddRoute("app.domain.com")
req := httptest.NewRequest("GET", "/app.domain.com", nil) req := httptest.NewRequest(http.MethodGet, "/app.domain.com", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
matcher.ServeHTTP(w, req) matcher.ServeHTTP(w, req)
@@ -169,7 +169,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Run("shortlink host", func(t *testing.T) { t.Run("shortlink host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
req.Host = common.ShortLinkPrefix req.Host = common.ShortLinkPrefix
w := httptest.NewRecorder() w := httptest.NewRecorder()
server.ServeHTTP(w, req) server.ServeHTTP(w, req)
@@ -179,7 +179,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
}) })
t.Run("shortlink host with port", func(t *testing.T) { t.Run("shortlink host with port", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
req.Host = common.ShortLinkPrefix + ":8080" req.Host = common.ShortLinkPrefix + ":8080"
w := httptest.NewRecorder() w := httptest.NewRecorder()
server.ServeHTTP(w, req) server.ServeHTTP(w, req)
@@ -189,7 +189,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
}) })
t.Run("normal host", func(t *testing.T) { t.Run("normal host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil) req := httptest.NewRequest(http.MethodGet, "/app", nil)
req.Host = "app.example.com" req.Host = "app.example.com"
w := httptest.NewRecorder() w := httptest.NewRecorder()
server.ServeHTTP(w, req) server.ServeHTTP(w, req)

View File

@@ -16,21 +16,23 @@ import (
type DockerHealthcheckState struct { type DockerHealthcheckState struct {
client *docker.SharedClient client *docker.SharedClient
containerId string containerID string
numDockerFailures int numDockerFailures int
} }
const dockerFailuresThreshold = 3 const dockerFailuresThreshold = 3
var ErrDockerHealthCheckFailedTooManyTimes = errors.New("docker health check failed too many times") var (
var ErrDockerHealthCheckNotAvailable = errors.New("docker health check not available") ErrDockerHealthCheckFailedTooManyTimes = errors.New("docker health check failed too many times")
ErrDockerHealthCheckNotAvailable = errors.New("docker health check not available")
)
func NewDockerHealthcheckState(client *docker.SharedClient, containerId string) *DockerHealthcheckState { func NewDockerHealthcheckState(client *docker.SharedClient, containerID string) *DockerHealthcheckState {
client.InterceptHTTPClient(interceptDockerInspectResponse) client.InterceptHTTPClient(interceptDockerInspectResponse)
return &DockerHealthcheckState{ return &DockerHealthcheckState{
client: client, client: client,
containerId: containerId, containerID: containerID,
numDockerFailures: 0, numDockerFailures: 0,
} }
} }
@@ -44,7 +46,7 @@ func Docker(ctx context.Context, state *DockerHealthcheckState, timeout time.Dur
defer cancel() defer cancel()
// the actual inspect response is intercepted and returned as RequestInterceptedError // the actual inspect response is intercepted and returned as RequestInterceptedError
_, err := state.client.ContainerInspect(ctx, state.containerId, client.ContainerInspectOptions{}) _, err := state.client.ContainerInspect(ctx, state.containerID, client.ContainerInspectOptions{})
var interceptedErr *httputils.RequestInterceptedError var interceptedErr *httputils.RequestInterceptedError
if !httputils.AsRequestInterceptedError(err, &interceptedErr) { if !httputils.AsRequestInterceptedError(err, &interceptedErr) {

View File

@@ -14,6 +14,7 @@ import (
config "github.com/yusing/godoxy/internal/config/types" config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/notif" "github.com/yusing/godoxy/internal/notif"
"github.com/yusing/godoxy/internal/types" "github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/events"
strutils "github.com/yusing/goutils/strings" strutils "github.com/yusing/goutils/strings"
"github.com/yusing/goutils/synk" "github.com/yusing/goutils/synk"
"github.com/yusing/goutils/task" "github.com/yusing/goutils/task"
@@ -269,6 +270,7 @@ func (mon *monitor) notifyServiceUp(logger *zerolog.Logger, result *types.Health
Body: extras, Body: extras,
Color: notif.ColorSuccess, Color: notif.ColorSuccess,
}) })
events.Global.Add(events.NewEvent(events.LevelInfo, "health", "service_up", mon))
} }
func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.HealthCheckResult) { func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.HealthCheckResult) {
@@ -281,6 +283,7 @@ func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.Heal
Body: extras, Body: extras,
Color: notif.ColorError, Color: notif.ColorError,
}) })
events.Global.Add(events.NewEvent(events.LevelWarn, "health", "service_down", mon))
} }
func (mon *monitor) buildNotificationExtras(result *types.HealthCheckResult) notif.FieldsBody { func (mon *monitor) buildNotificationExtras(result *types.HealthCheckResult) notif.FieldsBody {

View File

@@ -2,9 +2,9 @@ package monitor
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -14,8 +14,10 @@ import (
"github.com/yusing/godoxy/internal/types" "github.com/yusing/godoxy/internal/types"
) )
type Result = types.HealthCheckResult type (
type Monitor = types.HealthMonCheck Result = types.HealthCheckResult
Monitor = types.HealthMonCheck
)
// NewMonitor creates a health monitor based on the route type and configuration. // NewMonitor creates a health monitor based on the route type and configuration.
// //
@@ -78,22 +80,22 @@ func NewFileServerHealthMonitor(config types.HealthCheckConfig, path string) Mon
return &mon return &mon
} }
func NewStreamHealthMonitor(config types.HealthCheckConfig, targetUrl *url.URL) Monitor { func NewStreamHealthMonitor(config types.HealthCheckConfig, targetURL *url.URL) Monitor {
var mon monitor var mon monitor
mon.init(targetUrl, config, func(u *url.URL) (result Result, err error) { mon.init(targetURL, config, func(u *url.URL) (result Result, err error) {
return healthcheck.Stream(mon.Context(), u, config.Timeout) return healthcheck.Stream(mon.Context(), u, config.Timeout)
}) })
return &mon return &mon
} }
func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.SharedClient, containerId string, fallback Monitor) Monitor { func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.SharedClient, containerID string, fallback Monitor) Monitor {
state := healthcheck.NewDockerHealthcheckState(client, containerId) state := healthcheck.NewDockerHealthcheckState(client, containerID)
displayURL := &url.URL{ // only for display purposes, no actual request is made displayURL := &url.URL{ // only for display purposes, no actual request is made
Scheme: "docker", Scheme: "docker",
Host: client.DaemonHost(), Host: client.DaemonHost(),
Path: "/containers/" + containerId + "/json", Path: "/containers/" + containerID + "/json",
} }
logger := log.With().Str("host", client.DaemonHost()).Str("container_id", containerId).Logger() logger := log.With().Str("host", client.DaemonHost()).Str("container_id", containerID).Logger()
isFirstFailure := true isFirstFailure := true
var mon monitor var mon monitor
@@ -114,20 +116,20 @@ func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.Share
return &mon return &mon
} }
func NewAgentProxiedMonitor(config types.HealthCheckConfig, agent *agentpool.Agent, targetUrl *url.URL) Monitor { func NewAgentProxiedMonitor(config types.HealthCheckConfig, agent *agentpool.Agent, targetURL *url.URL) Monitor {
var mon monitor var mon monitor
mon.init(targetUrl, config, func(u *url.URL) (result Result, err error) { mon.init(targetURL, config, func(u *url.URL) (result Result, err error) {
return CheckHealthAgentProxied(agent, config.Timeout, u) return CheckHealthAgentProxied(agent, config.Timeout, u)
}) })
return &mon return &mon
} }
func CheckHealthAgentProxied(agent *agentpool.Agent, timeout time.Duration, targetUrl *url.URL) (Result, error) { func CheckHealthAgentProxied(agent *agentpool.Agent, timeout time.Duration, targetURL *url.URL) (Result, error) {
query := url.Values{ query := url.Values{
"scheme": {targetUrl.Scheme}, "scheme": {targetURL.Scheme},
"host": {targetUrl.Host}, "host": {targetURL.Host},
"path": {targetUrl.Path}, "path": {targetURL.Path},
"timeout": {fmt.Sprintf("%d", timeout.Milliseconds())}, "timeout": {strconv.FormatInt(timeout.Milliseconds(), 10)},
} }
resp, err := agent.DoHealthCheck(timeout, query.Encode()) resp, err := agent.DoHealthCheck(timeout, query.Encode())
result := Result{ result := Result{

View File

@@ -137,11 +137,11 @@ func fetchIcon(ctx context.Context, filename string) (Result, error) {
for _, fileType := range []string{"svg", "webp", "png"} { for _, fileType := range []string{"svg", "webp", "png"} {
result, err := fetchKnownIcon(ctx, icons.NewURL(icons.SourceSelfhSt, filename, fileType)) result, err := fetchKnownIcon(ctx, icons.NewURL(icons.SourceSelfhSt, filename, fileType))
if err == nil { if err == nil {
return result, err return result, nil
} }
result, err = fetchKnownIcon(ctx, icons.NewURL(icons.SourceWalkXCode, filename, fileType)) result, err = fetchKnownIcon(ctx, icons.NewURL(icons.SourceWalkXCode, filename, fileType))
if err == nil { if err == nil {
return result, err return result, nil
} }
} }
return FetchResultWithErrorf(http.StatusNotFound, "no icon found") return FetchResultWithErrorf(http.StatusNotFound, "no icon found")
@@ -152,6 +152,8 @@ type contextValue struct {
uri string uri string
} }
type contextKey struct{}
func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) (Result, error) { func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) (Result, error) {
for _, ref := range r.References() { for _, ref := range r.References() {
ref = sanitizeName(ref) ref = sanitizeName(ref)
@@ -160,7 +162,7 @@ func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) (
} }
result, err := fetchIcon(ctx, ref) result, err := fetchIcon(ctx, ref)
if err == nil { if err == nil {
return result, err return result, nil
} }
} }
if r, ok := r.(httpRoute); ok { if r, ok := r.(httpRoute); ok {
@@ -168,13 +170,13 @@ func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) (
return FetchResultWithErrorf(http.StatusServiceUnavailable, "service unavailable") return FetchResultWithErrorf(http.StatusServiceUnavailable, "service unavailable")
} }
// fallback to parse html // fallback to parse html
return findIconSlowCached(context.WithValue(ctx, "route", contextValue{r: r, uri: uri}), r.Key()) return findIconSlowCached(context.WithValue(ctx, contextKey{}, contextValue{r: r, uri: uri}), r.Key())
} }
return FetchResultWithErrorf(http.StatusNotFound, "no icon found") return FetchResultWithErrorf(http.StatusNotFound, "no icon found")
} }
var findIconSlowCached = cache.NewKeyFunc(func(ctx context.Context, key string) (Result, error) { var findIconSlowCached = cache.NewKeyFunc(func(ctx context.Context, key string) (Result, error) {
v := ctx.Value("route").(contextValue) v := ctx.Value(contextKey{}).(contextValue)
return findIconSlow(ctx, v.r, v.uri, nil) return findIconSlow(ctx, v.r, v.uri, nil)
}).WithMaxEntries(200).WithRetriesConstantBackoff(math.MaxInt, 15*time.Second).Build() // infinite retries, 15 seconds interval }).WithMaxEntries(200).WithRetriesConstantBackoff(math.MaxInt, 15*time.Second).Build() // infinite retries, 15 seconds interval

View File

@@ -11,7 +11,7 @@ var _ nettypes.Stream = (*Watcher)(nil)
// ListenAndServe implements nettypes.Stream. // ListenAndServe implements nettypes.Stream.
func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) error { func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) error {
return w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck return w.stream.ListenAndServe(ctx, func(ctx context.Context) error {
return w.preDial(ctx, predial) return w.preDial(ctx, predial)
}, func(ctx context.Context) error { }, func(ctx context.Context) error {
return w.onRead(ctx, onRead) return w.onRead(ctx, onRead)

View File

@@ -10,7 +10,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/entrypoint" "github.com/yusing/godoxy/internal/entrypoint"
. "github.com/yusing/godoxy/internal/net/gphttp/middleware" . "github.com/yusing/godoxy/internal/net/gphttp/middleware"
"github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route"
@@ -40,7 +39,7 @@ func TestBypassCIDR(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.RemoteAddr = test.remoteAddr req.RemoteAddr = test.remoteAddr
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
mr.ModifyRequest(noOpHandler, recorder, req) mr.ModifyRequest(noOpHandler, recorder, req)
@@ -76,7 +75,7 @@ func TestBypassPath(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
mr.ModifyRequest(noOpHandler, recorder, req) mr.ModifyRequest(noOpHandler, recorder, req)
expect.NoError(t, err) expect.NoError(t, err)
@@ -126,7 +125,7 @@ func TestReverseProxyBypass(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
rp.ServeHTTP(recorder, req) rp.ServeHTTP(recorder, req)
if test.expectBypass { if test.expectBypass {
@@ -160,7 +159,7 @@ func TestBypassResponse(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("test")), Body: io.NopCloser(strings.NewReader("test")),
@@ -201,7 +200,7 @@ func TestBypassResponse(t *testing.T) {
StatusCode: test.statusCode, StatusCode: test.statusCode,
Body: io.NopCloser(strings.NewReader("test")), Body: io.NopCloser(strings.NewReader("test")),
Header: make(http.Header), Header: make(http.Header),
Request: httptest.NewRequest("GET", "http://example.com", nil), Request: httptest.NewRequest(http.MethodGet, "http://example.com", nil),
} }
mErr := mr.ModifyResponse(resp) mErr := mr.ModifyResponse(resp)
expect.NoError(t, mErr) expect.NoError(t, mErr)
@@ -232,10 +231,12 @@ func TestEntrypointBypassRoute(t *testing.T) {
entry := entrypoint.NewTestEntrypoint(t, nil) entry := entrypoint.NewTestEntrypoint(t, nil)
_, err = route.NewStartedTestRoute(t, &route.Route{ _, err = route.NewStartedTestRoute(t, &route.Route{
Alias: "test-route", Alias: "test-route",
Host: host, Scheme: routeTypes.SchemeHTTP,
Host: host,
Port: routeTypes.Port{ Port: routeTypes.Port{
Proxy: portInt, Listening: 1000,
Proxy: portInt,
}, },
}) })
expect.NoError(t, err) expect.NoError(t, err)
@@ -255,8 +256,8 @@ func TestEntrypointBypassRoute(t *testing.T) {
expect.NoError(t, err) expect.NoError(t, err)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
req := httptest.NewRequest("GET", "http://test-route.example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://test-route.example.com", nil)
server, ok := entry.GetServer(common.ProxyHTTPAddr) server, ok := entry.GetServer(":1000")
if !ok { if !ok {
t.Fatal("server not found") t.Fatal("server not found")
} }

View File

@@ -105,7 +105,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*nettypes.CIDR) error {
return err return err
} }
resp, err := http.DefaultClient.Do(req) //nolint:gosec resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -48,7 +49,7 @@ func (m *crowdsecMiddleware) setup() {
func (m *crowdsecMiddleware) finalize() error { func (m *crowdsecMiddleware) finalize() error {
if !strings.HasPrefix(m.Endpoint, "/") { if !strings.HasPrefix(m.Endpoint, "/") {
return fmt.Errorf("endpoint must start with /") return errors.New("endpoint must start with /")
} }
if m.Timeout == 0 { if m.Timeout == 0 {
m.Timeout = 5 * time.Second m.Timeout = 5 * time.Second
@@ -179,12 +180,12 @@ func (m *crowdsecMiddleware) buildCrowdSecURL(ctx context.Context) (string, erro
// If not found in routes, assume it's an IP address // If not found in routes, assume it's an IP address
if m.Port == 0 { if m.Port == 0 {
return "", fmt.Errorf("port must be specified when using IP address") return "", errors.New("port must be specified when using IP address")
} }
return fmt.Sprintf("http://%s%s", net.JoinHostPort(m.Route, strconv.Itoa(m.Port)), m.Endpoint), nil return fmt.Sprintf("http://%s%s", net.JoinHostPort(m.Route, strconv.Itoa(m.Port)), m.Endpoint), nil
} }
return "", fmt.Errorf("route or IP address must be specified") return "", errors.New("route or IP address must be specified")
} }
func (m *crowdsecMiddleware) getHTTPVersion(r *http.Request) string { func (m *crowdsecMiddleware) getHTTPVersion(r *http.Request) string {

View File

@@ -59,7 +59,7 @@ func (client *GotifyClient) MarshalMessage(logMsg *LogMessage) ([]byte, error) {
} }
if client.Format == LogFormatMarkdown { if client.Format == LogFormatMarkdown {
msg.Extras = map[string]interface{}{ msg.Extras = map[string]any{
"client::display": map[string]string{ "client::display": map[string]string{
"contentType": "text/markdown", "contentType": "text/markdown",
}, },

View File

@@ -20,6 +20,7 @@ import (
type Client struct { type Client struct {
*proxmox.Client *proxmox.Client
*proxmox.Cluster *proxmox.Cluster
Version *proxmox.Version Version *proxmox.Version
BaseURL *url.URL BaseURL *url.URL
// id -> resource; id: lxc/<vmid> or qemu/<vmid> // id -> resource; id: lxc/<vmid> or qemu/<vmid>
@@ -29,6 +30,7 @@ type Client struct {
type VMResource struct { type VMResource struct {
*proxmox.ClusterResource *proxmox.ClusterResource
IPs []net.IP IPs []net.IP
} }
@@ -37,9 +39,9 @@ var (
ErrNoResources = errors.New("no resources") ErrNoResources = errors.New("no resources")
) )
func NewClient(baseUrl string, opts ...proxmox.Option) *Client { func NewClient(baseURL string, opts ...proxmox.Option) *Client {
return &Client{ return &Client{
Client: proxmox.NewClient(baseUrl, opts...), Client: proxmox.NewClient(baseURL, opts...),
resources: make(map[string]*VMResource), resources: make(map[string]*VMResource),
} }
} }

View File

@@ -109,7 +109,7 @@ func (n *Node) LXCIsStopped(ctx context.Context, vmid uint64) (bool, error) {
} }
func (n *Node) LXCSetShutdownTimeout(ctx context.Context, vmid uint64, timeout time.Duration) error { func (n *Node) LXCSetShutdownTimeout(ctx context.Context, vmid uint64, timeout time.Duration) error {
return n.client.Put(ctx, fmt.Sprintf("/nodes/%s/lxc/%d/config", n.name, vmid), map[string]interface{}{ return n.client.Put(ctx, fmt.Sprintf("/nodes/%s/lxc/%d/config", n.name, vmid), map[string]any{
"startup": fmt.Sprintf("down=%.0f", timeout.Seconds()), "startup": fmt.Sprintf("down=%.0f", timeout.Seconds()),
}, nil) }, nil)
} }

View File

@@ -91,7 +91,7 @@ func (p *Provider) GetType() provider.Type {
return p.t return p.t
} }
// to work with json marshaller. // MarshalText implements encoding.TextMarshaler.
func (p *Provider) MarshalText() ([]byte, error) { func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil return []byte(p.String()), nil
} }

View File

@@ -57,7 +57,7 @@ type (
PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"`
Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"` Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"`
RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"`
HealthCheck types.HealthCheckConfig `json:"healthcheck,omitempty" extensions:"x-nullable"` // null on load-balancer routes HealthCheck types.HealthCheckConfig `json:"healthcheck,omitzero" extensions:"x-nullable"` // null on load-balancer routes
LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"`
Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"` Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"`
Homepage *homepage.ItemConfig `json:"homepage"` Homepage *homepage.ItemConfig `json:"homepage"`
@@ -276,10 +276,10 @@ func (r *Route) validate() error {
case route.SchemeFileServer: case route.SchemeFileServer:
r.Host = "" r.Host = ""
r.Port.Proxy = 0 r.Port.Proxy = 0
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("https://%s", net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))) r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root) r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C: case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("https://%s", net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))) r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s", r.Scheme, net.JoinHostPort(r.Host, strconv.Itoa(r.Port.Proxy)))) r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s", r.Scheme, net.JoinHostPort(r.Host, strconv.Itoa(r.Port.Proxy))))
case route.SchemeTCP, route.SchemeUDP: case route.SchemeTCP, route.SchemeUDP:
bindIP := net.ParseIP(r.Bind) bindIP := net.ParseIP(r.Bind)
@@ -588,10 +588,9 @@ func (r *Route) References() []string {
return []string{r.Proxmox.VMName, aliasRef, r.Proxmox.Services[0]} return []string{r.Proxmox.VMName, aliasRef, r.Proxmox.Services[0]}
} }
return []string{r.Proxmox.Services[0], aliasRef} return []string{r.Proxmox.Services[0], aliasRef}
} else { }
if r.Proxmox.VMName != aliasRef { if r.Proxmox.VMName != aliasRef {
return []string{r.Proxmox.VMName, aliasRef} return []string{r.Proxmox.VMName, aliasRef}
}
} }
} }
return []string{aliasRef} return []string{aliasRef}

View File

@@ -76,6 +76,7 @@ var commands = map[string]struct {
if len(args) != 0 { if len(args) != 0 {
return nil, ErrExpectNoArg return nil, ErrExpectNoArg
} }
//nolint:nilnil
return nil, nil return nil, nil
}, },
build: func(args any) CommandHandler { build: func(args any) CommandHandler {
@@ -329,7 +330,7 @@ var commands = map[string]struct {
helpExample(CommandSet, "header", "User-Agent", "godoxy"), helpExample(CommandSet, "header", "User-Agent", "godoxy"),
), ),
args: map[string]string{ args: map[string]string{
"target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")), "target": "the target to set, can be " + strings.Join(AllFields, ", "),
"field": "the field to set", "field": "the field to set",
"value": "the value to set", "value": "the value to set",
}, },
@@ -349,7 +350,7 @@ var commands = map[string]struct {
helpExample(CommandAdd, "header", "X-Foo", "bar"), helpExample(CommandAdd, "header", "X-Foo", "bar"),
), ),
args: map[string]string{ args: map[string]string{
"target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")), "target": "the target to add, can be " + strings.Join(AllFields, ", "),
"field": "the field to add", "field": "the field to add",
"value": "the value to add", "value": "the value to add",
}, },
@@ -369,7 +370,7 @@ var commands = map[string]struct {
helpExample(CommandRemove, "header", "User-Agent"), helpExample(CommandRemove, "header", "User-Agent"),
), ),
args: map[string]string{ args: map[string]string{
"target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")), "target": "the target to remove, can be " + strings.Join(AllFields, ", "),
"field": "the field to remove", "field": "the field to remove",
}, },
}, },
@@ -511,8 +512,10 @@ var commands = map[string]struct {
}, },
} }
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString] type (
type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString] onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
)
// Parse implements strutils.Parser. // Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error { func (cmd *Command) Parse(v string) error {

View File

@@ -53,7 +53,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("POST", "/api/users", nil) req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
req.Header.Set("User-Agent", "test-agent") req.Header.Set("User-Agent", "test-agent")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -70,7 +70,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
} }
func TestLogCommand_StdoutAndStderr(t *testing.T) { func TestLogCommand_StdoutAndStderr(t *testing.T) {
upstream := mockUpstream(200, "success") upstream := mockUpstream(http.StatusOK, "success")
var rules Rules var rules Rules
err := parseRules(` err := parseRules(`
@@ -85,7 +85,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
@@ -96,7 +96,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
} }
func TestLogCommand_DifferentLogLevels(t *testing.T) { func TestLogCommand_DifferentLogLevels(t *testing.T) {
upstream := mockUpstream(404, "not found") upstream := mockUpstream(http.StatusNotFound, "not found")
infoFile := TestRandomFileName() infoFile := TestRandomFileName()
warnFile := TestRandomFileName() warnFile := TestRandomFileName()
@@ -140,7 +140,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", "custom-value") w.Header().Set("X-Custom-Header", "custom-value")
w.Header().Set("Content-Length", "42") w.Header().Set("Content-Length", "42")
w.WriteHeader(201) w.WriteHeader(http.StatusCreated)
w.Write([]byte("created")) w.Write([]byte("created"))
}) })
@@ -176,13 +176,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/error": case "/error":
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal server error")) w.Write([]byte("internal server error"))
case "/notfound": case "/notfound":
w.WriteHeader(404) w.WriteHeader(http.StatusNotFound)
w.Write([]byte("not found")) w.Write([]byte("not found"))
default: default:
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("success")) w.Write([]byte("success"))
} }
}) })
@@ -206,22 +206,22 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test success request // Test success request
req1 := httptest.NewRequest("GET", "/success", nil) req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
// Test not found request // Test not found request
req2 := httptest.NewRequest("GET", "/notfound", nil) req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code) assert.Equal(t, http.StatusNotFound, w2.Code)
// Test server error request // Test server error request
req3 := httptest.NewRequest("POST", "/error", nil) req3 := httptest.NewRequest(http.MethodPost, "/error", nil)
w3 := httptest.NewRecorder() w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3) handler.ServeHTTP(w3, req3)
assert.Equal(t, 500, w3.Code) assert.Equal(t, http.StatusInternalServerError, w3.Code)
// Verify success log // Verify success log
successContent := TestFileContent(successFile) successContent := TestFileContent(successFile)
@@ -238,7 +238,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
} }
func TestLogCommand_MultipleLogEntries(t *testing.T) { func TestLogCommand_MultipleLogEntries(t *testing.T) {
upstream := mockUpstream(200, "response") upstream := mockUpstream(http.StatusOK, "response")
tempFile := TestRandomFileName() tempFile := TestRandomFileName()
@@ -266,7 +266,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil) req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
} }
// Verify all requests were logged // Verify all requests were logged

View File

@@ -67,7 +67,7 @@ func TestFieldHandler_Header(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req) tt.setup(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -126,8 +126,8 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
verify: func(w *httptest.ResponseRecorder) { verify: func(w *httptest.ResponseRecorder) {
values := w.Header()["X-Response-Test"] values := w.Header()["X-Response-Test"]
require.Len(t, values, 2) require.Len(t, values, 2)
assert.Equal(t, values[0], "existing-value") assert.Equal(t, "existing-value", values[0])
assert.Equal(t, values[1], "additional-value") assert.Equal(t, "additional-value", values[1])
}, },
}, },
{ {
@@ -143,7 +143,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
if tt.setup != nil { if tt.setup != nil {
tt.setup(w) tt.setup(w)
@@ -232,7 +232,7 @@ func TestFieldHandler_Query(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req) tt.setup(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -330,7 +330,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req) tt.setup(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -396,7 +396,7 @@ func TestFieldHandler_Body(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req) tt.setup(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -440,7 +440,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req) tt.setup(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -494,7 +494,7 @@ func TestFieldHandler_StatusCode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w) rm := httputils.NewResponseModifier(w)
var cmd Command var cmd Command

View File

@@ -23,9 +23,8 @@ import (
) )
// mockUpstream creates a simple upstream handler for testing // mockUpstream creates a simple upstream handler for testing
func mockUpstream(status int, body string) http.HandlerFunc { func mockUpstream(body string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status)
w.Write([]byte(body)) w.Write([]byte(body))
} }
} }
@@ -51,7 +50,7 @@ func parseRules(data string, target *Rules) error {
func TestHTTPFlow_BasicPreRules(t *testing.T) { func TestHTTPFlow_BasicPreRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header")) w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("upstream response")) w.Write([]byte("upstream response"))
}) })
@@ -65,18 +64,18 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "upstream response", w.Body.String()) assert.Equal(t, "upstream response", w.Body.String())
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header")) assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
} }
func TestHTTPFlow_BypassRule(t *testing.T) { func TestHTTPFlow_BypassRule(t *testing.T) {
upstream := mockUpstream(200, "upstream response") upstream := mockUpstream("upstream response")
var rules Rules var rules Rules
err := parseRules(` err := parseRules(`
@@ -91,17 +90,17 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/bypass", nil) req := httptest.NewRequest(http.MethodGet, "/bypass", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "upstream response", w.Body.String()) assert.Equal(t, "upstream response", w.Body.String())
} }
func TestHTTPFlow_TerminatingCommand(t *testing.T) { func TestHTTPFlow_TerminatingCommand(t *testing.T) {
upstream := mockUpstream(200, "should not be called") upstream := mockUpstream("should not be called")
var rules Rules var rules Rules
err := parseRules(` err := parseRules(`
@@ -116,18 +115,18 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/error", nil) req := httptest.NewRequest(http.MethodGet, "/error", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code) assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equal(t, "Forbidden\n", w.Body.String()) assert.Equal(t, "Forbidden\n", w.Body.String())
assert.Empty(t, w.Header().Get("X-Header")) assert.Empty(t, w.Header().Get("X-Header"))
} }
func TestHTTPFlow_RedirectFlow(t *testing.T) { func TestHTTPFlow_RedirectFlow(t *testing.T) {
upstream := mockUpstream(200, "should not be called") upstream := mockUpstream("should not be called")
var rules Rules var rules Rules
err := parseRules(` err := parseRules(`
@@ -139,18 +138,18 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/old-path", nil) req := httptest.NewRequest(http.MethodGet, "/old-path", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 307, w.Code) // TemporaryRedirect assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
assert.Equal(t, "/new-path", w.Header().Get("Location")) assert.Equal(t, "/new-path", w.Header().Get("Location"))
} }
func TestHTTPFlow_RewriteFlow(t *testing.T) { func TestHTTPFlow_RewriteFlow(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("path: " + r.URL.Path)) w.Write([]byte("path: " + r.URL.Path))
}) })
@@ -164,18 +163,18 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/api/users", nil) req := httptest.NewRequest(http.MethodGet, "/api/users", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "path: /v1/users", w.Body.String()) assert.Equal(t, "path: /v1/users", w.Body.String())
} }
func TestHTTPFlow_MultiplePreRules(t *testing.T) { func TestHTTPFlow_MultiplePreRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id"))) w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
}) })
@@ -192,18 +191,18 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "upstream: req-123", w.Body.String()) assert.Equal(t, "upstream: req-123", w.Body.String())
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token")) assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
} }
func TestHTTPFlow_PostResponseRule(t *testing.T) { func TestHTTPFlow_PostResponseRule(t *testing.T) {
upstream := mockUpstreamWithHeaders(200, "success", http.Header{ upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
"X-Upstream": []string{"upstream-value"}, "X-Upstream": []string{"upstream-value"},
}) })
@@ -219,12 +218,12 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "success", w.Body.String()) assert.Equal(t, "success", w.Body.String())
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
@@ -237,10 +236,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/success" { if r.URL.Path == "/success" {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("success")) w.Write([]byte("success"))
} else { } else {
w.WriteHeader(404) w.WriteHeader(http.StatusNotFound)
w.Write([]byte("not found")) w.Write([]byte("not found"))
} }
}) })
@@ -260,18 +259,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test successful request (should not log) // Test successful request (should not log)
req1 := httptest.NewRequest("GET", "/success", nil) req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
// Test error request (should log) // Test error request (should log)
req2 := httptest.NewRequest("GET", "/notfound", nil) req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code) assert.Equal(t, http.StatusNotFound, w2.Code)
// Check log file // Check log file
content := TestFileContent(tempFile) content := TestFileContent(tempFile)
@@ -283,7 +282,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
func TestHTTPFlow_ConditionalRules(t *testing.T) { func TestHTTPFlow_ConditionalRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("hello " + r.Header.Get("X-Username"))) w.Write([]byte("hello " + r.Header.Get("X-Username")))
}) })
@@ -304,19 +303,19 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test with Authorization header // Test with Authorization header
req1 := httptest.NewRequest("GET", "/", nil) req1 := httptest.NewRequest(http.MethodGet, "/", nil)
req1.Header.Set("Authorization", "Bearer token") req1.Header.Set("Authorization", "Bearer token")
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "hello authenticated-user", w1.Body.String()) assert.Equal(t, "hello authenticated-user", w1.Body.String())
assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username")) assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username"))
// Test without Authorization header // Test without Authorization header
req2 := httptest.NewRequest("GET", "/", nil) req2 := httptest.NewRequest(http.MethodGet, "/", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "hello anonymous", w2.Body.String()) assert.Equal(t, "hello anonymous", w2.Body.String())
assert.Equal(t, "anonymous", w2.Header().Get("X-Username")) assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
} }
@@ -326,13 +325,13 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
// Simulate different responses based on path // Simulate different responses based on path
if r.URL.Path == "/protected" { if r.URL.Path == "/protected" {
if r.Header.Get("X-Auth") != "valid" { if r.Header.Get("X-Auth") != "valid" {
w.WriteHeader(401) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("unauthorized")) w.Write([]byte("unauthorized"))
return return
} }
} }
w.Header().Set("X-Response-Time", "100ms") w.Header().Set("X-Response-Time", "100ms")
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("success")) w.Write([]byte("success"))
}) })
@@ -360,32 +359,32 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test successful request // Test successful request
req1 := httptest.NewRequest("GET", "/public", nil) req1 := httptest.NewRequest(http.MethodGet, "/public", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "success", w1.Body.String()) assert.Equal(t, "success", w1.Body.String())
assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id")) assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id"))
assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time")) assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time"))
// Test unauthorized protected request // Test unauthorized protected request
req2 := httptest.NewRequest("GET", "/protected", nil) req2 := httptest.NewRequest(http.MethodGet, "/protected", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 401, w2.Code) assert.Equal(t, http.StatusUnauthorized, w2.Code)
assert.Equal(t, w2.Body.String(), "Unauthorized\n") assert.Equal(t, "Unauthorized\n", w2.Body.String())
// Test authorized protected request // Test authorized protected request
req3 := httptest.NewRequest("GET", "/protected", nil) req3 := httptest.NewRequest(http.MethodGet, "/protected", nil)
req3.SetBasicAuth("user", "pass") req3.SetBasicAuth("user", "pass")
w3 := httptest.NewRecorder() w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3) handler.ServeHTTP(w3, req3)
// This should fail because our simple upstream expects X-Auth: valid header // This should fail because our simple upstream expects X-Auth: valid header
// but the basic auth requirement should add the appropriate header // but the basic auth requirement should add the appropriate header
assert.Equal(t, 401, w3.Code) assert.Equal(t, http.StatusUnauthorized, w3.Code)
// Check log files // Check log files
logContent := TestFileContent(logFile) logContent := TestFileContent(logFile)
@@ -404,7 +403,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
} }
func TestHTTPFlow_DefaultRule(t *testing.T) { func TestHTTPFlow_DefaultRule(t *testing.T) {
upstream := mockUpstream(200, "upstream response") upstream := mockUpstream("upstream response")
var rules Rules var rules Rules
err := parseRules(` err := parseRules(`
@@ -419,20 +418,20 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test default rule // Test default rule
req1 := httptest.NewRequest("GET", "/regular", nil) req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
assert.Empty(t, w1.Header().Get("X-Special-Handled")) assert.Empty(t, w1.Header().Get("X-Special-Handled"))
// Test special rule + default rule // Test special rule + default rule
req2 := httptest.NewRequest("GET", "/special", nil) req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
} }
@@ -442,7 +441,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
// Echo back a header // Echo back a header
headerValue := r.Header.Get("X-Test-Header") headerValue := r.Header.Get("X-Test-Header")
w.Header().Set("X-Echoed-Header", headerValue) w.Header().Set("X-Echoed-Header", headerValue)
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("header echoed")) w.Write([]byte("header echoed"))
}) })
@@ -460,14 +459,14 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Secret", "secret-value") req.Header.Set("X-Secret", "secret-value")
req.Header.Set("X-Test-Header", "original-value") req.Header.Set("X-Test-Header", "original-value")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header")) assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header")) assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
// Ensure the secret header was removed and not passed to upstream // Ensure the secret header was removed and not passed to upstream
@@ -477,7 +476,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
func TestHTTPFlow_QueryParameterHandling(t *testing.T) { func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query() query := r.URL.Query()
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("query: " + query.Get("param"))) w.Write([]byte("query: " + query.Get("param")))
}) })
@@ -491,25 +490,23 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/path?param=original", nil) req := httptest.NewRequest(http.MethodGet, "/path?param=original", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
// The set command should have modified the query parameter // The set command should have modified the query parameter
assert.Equal(t, "query: added-value", w.Body.String()) assert.Equal(t, "query: added-value", w.Body.String())
} }
func TestHTTPFlow_ServeCommand(t *testing.T) { func TestHTTPFlow_ServeCommand(t *testing.T) {
// Create a temporary directory with test files // Create a temporary directory with test files
tempDir, err := os.MkdirTemp("", "test-serve-*") tempDir := t.TempDir()
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create test files directly in the temp directory // Create test files directly in the temp directory
testFile := filepath.Join(tempDir, "index.html") testFile := filepath.Join(tempDir, "index.html")
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0644) err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
require.NoError(t, err) require.NoError(t, err)
var rules Rules var rules Rules
@@ -520,7 +517,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
`, tempDir), &rules) `, tempDir), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream(200, "should not be called")) handler := rules.BuildHandler(mockUpstream("should not be called"))
// Test serving a file - serve command serves files relative to the root directory // Test serving a file - serve command serves files relative to the root directory
// The path /files/index.html gets mapped to tempDir + "/files/index.html" // The path /files/index.html gets mapped to tempDir + "/files/index.html"
@@ -533,7 +530,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644) err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644)
require.NoError(t, err) require.NoError(t, err)
req1 := httptest.NewRequest("GET", "/files/index.html", nil) req1 := httptest.NewRequest(http.MethodGet, "/files/index.html", nil)
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
@@ -542,18 +539,18 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
assert.NotEqual(t, "should not be called", w1.Body.String()) assert.NotEqual(t, "should not be called", w1.Body.String())
// Test file not found // Test file not found
req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil) req2 := httptest.NewRequest(http.MethodGet, "/files/nonexistent.html", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code) assert.Equal(t, http.StatusNotFound, w2.Code)
} }
func TestHTTPFlow_ProxyCommand(t *testing.T) { func TestHTTPFlow_ProxyCommand(t *testing.T) {
// Create a mock upstream server // Create a mock upstream server
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream-Header", "upstream-value") w.Header().Set("X-Upstream-Header", "upstream-value")
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("upstream response")) w.Write([]byte("upstream response"))
})) }))
defer upstreamServer.Close() defer upstreamServer.Close()
@@ -566,15 +563,15 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
`, upstreamServer.URL), &rules) `, upstreamServer.URL), &rules)
require.NoError(t, err) require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream(200, "should not be called")) handler := rules.BuildHandler(mockUpstream("should not be called"))
req := httptest.NewRequest("GET", "/api/test", nil) req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
// The proxy command should forward the request to the upstream server // The proxy command should forward the request to the upstream server
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "upstream response", w.Body.String()) assert.Equal(t, "upstream response", w.Body.String())
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header")) assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
} }
@@ -585,7 +582,7 @@ func TestHTTPFlow_NotifyCommand(t *testing.T) {
func TestHTTPFlow_FormConditions(t *testing.T) { func TestHTTPFlow_FormConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("form processed")) w.Write([]byte("form processed"))
}) })
@@ -604,28 +601,28 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
// Test form condition // Test form condition
formData := url.Values{"username": {"john_doe"}} formData := url.Values{"username": {"john_doe"}}
req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode())) req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(formData.Encode()))
req1.Header.Set("Content-Type", "application/x-www-form-urlencoded") req1.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "john_doe", w1.Header().Get("X-Username")) assert.Equal(t, "john_doe", w1.Header().Get("X-Username"))
// Test postform condition // Test postform condition
postFormData := url.Values{"email": {"john@example.com"}} postFormData := url.Values{"email": {"john@example.com"}}
req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode())) req2 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(postFormData.Encode()))
req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email")) assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
} }
func TestHTTPFlow_RemoteConditions(t *testing.T) { func TestHTTPFlow_RemoteConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("remote processed")) w.Write([]byte("remote processed"))
}) })
@@ -643,27 +640,27 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test localhost condition // Test localhost condition
req1 := httptest.NewRequest("GET", "/", nil) req1 := httptest.NewRequest(http.MethodGet, "/", nil)
req1.RemoteAddr = "127.0.0.1:12345" req1.RemoteAddr = "127.0.0.1:12345"
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "local", w1.Header().Get("X-Access")) assert.Equal(t, "local", w1.Header().Get("X-Access"))
// Test private network block // Test private network block
req2 := httptest.NewRequest("GET", "/", nil) req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2.RemoteAddr = "192.168.1.100:12345" req2.RemoteAddr = "192.168.1.100:12345"
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 403, w2.Code) assert.Equal(t, http.StatusForbidden, w2.Code)
assert.Equal(t, "Private network blocked\n", w2.Body.String()) assert.Equal(t, "Private network blocked\n", w2.Body.String())
} }
func TestHTTPFlow_BasicAuthConditions(t *testing.T) { func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("auth processed")) w.Write([]byte("auth processed"))
}) })
@@ -687,27 +684,27 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test admin user // Test admin user
req1 := httptest.NewRequest("GET", "/", nil) req1 := httptest.NewRequest(http.MethodGet, "/", nil)
req1.SetBasicAuth("admin", "adminpass") req1.SetBasicAuth("admin", "adminpass")
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status")) assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status"))
// Test guest user // Test guest user
req2 := httptest.NewRequest("GET", "/", nil) req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2.SetBasicAuth("guest", "guestpass") req2.SetBasicAuth("guest", "guestpass")
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status")) assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
} }
func TestHTTPFlow_RouteConditions(t *testing.T) { func TestHTTPFlow_RouteConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("route processed")) w.Write([]byte("route processed"))
}) })
@@ -725,29 +722,29 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test API route // Test API route
req1 := httptest.NewRequest("GET", "/", nil) req1 := httptest.NewRequest(http.MethodGet, "/", nil)
req1 = routes.WithRouteContext(req1, mockRoute("backend")) req1 = routes.WithRouteContext(req1, mockRoute("backend"))
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "backend", w1.Header().Get("X-Route")) assert.Equal(t, "backend", w1.Header().Get("X-Route"))
// Test admin route // Test admin route
req2 := httptest.NewRequest("GET", "/", nil) req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2 = routes.WithRouteContext(req2, mockRoute("frontend")) req2 = routes.WithRouteContext(req2, mockRoute("frontend"))
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "frontend", w2.Header().Get("X-Route")) assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
} }
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(405) w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("method not allowed")) w.Write([]byte("method not allowed"))
}) })
@@ -762,18 +759,18 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code) assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
assert.Equal(t, "error\n", w.Body.String()) assert.Equal(t, "error\n", w.Body.String())
} }
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Response-Header", "response header") w.Header().Set("X-Response-Header", "response header")
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("processed")) w.Write([]byte("processed"))
}) })
@@ -788,11 +785,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code) assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
assert.Equal(t, "error\n", w.Body.String()) assert.Equal(t, "error\n", w.Body.String())
}) })
t.Run("with_value", func(t *testing.T) { t.Run("with_value", func(t *testing.T) {
@@ -806,11 +803,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code) assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
assert.Equal(t, "error\n", w.Body.String()) assert.Equal(t, "error\n", w.Body.String())
}) })
@@ -825,18 +822,18 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "processed", w.Body.String()) assert.Equal(t, "processed", w.Body.String())
}) })
} }
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("complex processed")) w.Write([]byte("complex processed"))
}) })
@@ -867,26 +864,26 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
// Test admin API (should match first rule) // Test admin API (should match first rule)
req1 := httptest.NewRequest("POST", "/api/admin/users", nil) req1 := httptest.NewRequest(http.MethodPost, "/api/admin/users", nil)
req1.Header.Set("Authorization", "Bearer token") req1.Header.Set("Authorization", "Bearer token")
w1 := httptest.NewRecorder() w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1) handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code) assert.Equal(t, http.StatusOK, w1.Code)
assert.Equal(t, "admin", w1.Header().Get("X-Access-Level")) assert.Equal(t, "admin", w1.Header().Get("X-Access-Level"))
assert.Equal(t, "v1", w1.Header()["X-API-Version"][0]) assert.Equal(t, "v1", w1.Header()["X-API-Version"][0])
// Test user API (should match second rule) // Test user API (should match second rule)
req2 := httptest.NewRequest("GET", "/api/users/profile", nil) req2 := httptest.NewRequest(http.MethodGet, "/api/users/profile", nil)
w2 := httptest.NewRecorder() w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2) handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code) assert.Equal(t, http.StatusOK, w2.Code)
assert.Equal(t, "user", w2.Header().Get("X-Access-Level")) assert.Equal(t, "user", w2.Header().Get("X-Access-Level"))
assert.Equal(t, "v1", w2.Header()["X-API-Version"][0]) assert.Equal(t, "v1", w2.Header()["X-API-Version"][0])
// Test public API (should match third rule) // Test public API (should match third rule)
req3 := httptest.NewRequest("GET", "/api/public/info", nil) req3 := httptest.NewRequest(http.MethodGet, "/api/public/info", nil)
w3 := httptest.NewRecorder() w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3) handler.ServeHTTP(w3, req3)
@@ -897,7 +894,7 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
func TestHTTPFlow_ResponseModifier(t *testing.T) { func TestHTTPFlow_ResponseModifier(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write([]byte("original response")) w.Write([]byte("original response"))
}) })
@@ -912,12 +909,12 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) {
handler := rules.BuildHandler(upstream) handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "true", w.Header().Get("X-Modified")) assert.Equal(t, "true", w.Header().Get("X-Modified"))
assert.Equal(t, "Modified: GET /test\n", w.Body.String()) assert.Equal(t, "Modified: GET /test\n", w.Body.String())
} }

View File

@@ -41,6 +41,7 @@ const (
OnRoute = "route" OnRoute = "route"
// on response // on response
OnResponseHeader = "resp_header" OnResponseHeader = "resp_header"
OnStatus = "status" OnStatus = "status"
) )
@@ -63,6 +64,7 @@ var checkers = map[string]struct {
if len(args) != 0 { if len(args) != 0 {
return nil, ErrExpectNoArg return nil, ErrExpectNoArg
} }
//nolint:nilnil
return nil, nil return nil, nil
}, },
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called

View File

@@ -1,11 +1,9 @@
package rules package rules
import ( import (
"os"
"strconv" "strconv"
"testing" "testing"
gperr "github.com/yusing/goutils/errs"
expect "github.com/yusing/goutils/testing" expect "github.com/yusing/goutils/testing"
) )
@@ -15,7 +13,6 @@ func TestParser(t *testing.T) {
input string input string
subject string subject string
args []string args []string
wantErr gperr.Error
}{ }{
{ {
name: "basic", name: "basic",
@@ -93,10 +90,6 @@ func TestParser(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
subject, args, err := parse(tt.input) subject, args, err := parse(tt.input)
if tt.wantErr != nil {
expect.ErrorIs(t, tt.wantErr, err)
return
}
// t.Log(subject, args, err) // t.Log(subject, args, err)
expect.NoError(t, err) expect.NoError(t, err)
expect.Equal(t, subject, tt.subject) expect.Equal(t, subject, tt.subject)
@@ -105,12 +98,8 @@ func TestParser(t *testing.T) {
} }
t.Run("env substitution", func(t *testing.T) { t.Run("env substitution", func(t *testing.T) {
// Set up test environment variables // Set up test environment variables
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123") t.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
os.Setenv("DOMAIN", "example.com") t.Setenv("DOMAIN", "example.com")
defer func() {
os.Unsetenv("CLOUDFLARE_API_KEY")
os.Unsetenv("DOMAIN")
}()
tests := []struct { tests := []struct {
name string name string

View File

@@ -2,6 +2,7 @@ package rules
import ( import (
"io" "io"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@@ -11,9 +12,9 @@ import (
func BenchmarkExpandVars(b *testing.B) { func BenchmarkExpandVars(b *testing.B) {
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(http.StatusOK)
testResponseModifier.Write([]byte("Hello, world!")) testResponseModifier.Write([]byte("Hello, world!"))
testRequest := httptest.NewRequest("GET", "/", nil) testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
testRequest.Header.Set("User-Agent", "test-agent/1.0") testRequest.Header.Set("User-Agent", "test-agent/1.0")
testRequest.Header.Set("X-Custom", "value1,value2") testRequest.Header.Set("X-Custom", "value1,value2")
testRequest.ContentLength = 12345 testRequest.ContentLength = 12345

View File

@@ -203,7 +203,7 @@ func TestExpandVars(t *testing.T) {
postFormData.Add("postmulti", "first") postFormData.Add("postmulti", "first")
postFormData.Add("postmulti", "second") postFormData.Add("postmulti", "second")
testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1&param2=value2#fragment", strings.NewReader(postFormData.Encode())) testRequest := httptest.NewRequest(http.MethodPost, "https://example.com:8080/api/users?param1=value1&param2=value2#fragment", strings.NewReader(postFormData.Encode()))
testRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded") testRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testRequest.Header.Set("User-Agent", "test-agent/1.0") testRequest.Header.Set("User-Agent", "test-agent/1.0")
testRequest.Header.Add("X-Custom", "value1") testRequest.Header.Add("X-Custom", "value1")
@@ -218,7 +218,7 @@ func TestExpandVars(t *testing.T) {
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
testResponseModifier.Header().Set("Content-Type", "text/html") testResponseModifier.Header().Set("Content-Type", "text/html")
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value") testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(http.StatusOK)
// set content length to 9876 by writing 9876 'a' bytes // set content length to 9876 by writing 9876 'a' bytes
testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876)) testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876))
@@ -498,12 +498,12 @@ func TestExpandVars(t *testing.T) {
func TestExpandVars_Integration(t *testing.T) { func TestExpandVars_Integration(t *testing.T) {
t.Run("complex log format", func(t *testing.T) { t.Run("complex log format", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil) testRequest := httptest.NewRequest(http.MethodGet, "https://api.example.com/users/123?sort=asc", nil)
testRequest.Header.Set("User-Agent", "curl/7.68.0") testRequest.Header.Set("User-Agent", "curl/7.68.0")
testRequest.RemoteAddr = "10.0.0.1:54321" testRequest.RemoteAddr = "10.0.0.1:54321"
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(http.StatusOK)
var out strings.Builder var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, err := ExpandVars(testResponseModifier, testRequest,
@@ -515,7 +515,7 @@ func TestExpandVars_Integration(t *testing.T) {
}) })
t.Run("with query parameters", func(t *testing.T) { t.Run("with query parameters", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil) testRequest := httptest.NewRequest(http.MethodGet, "http://example.com/search?q=test&page=1", nil)
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
@@ -529,12 +529,12 @@ func TestExpandVars_Integration(t *testing.T) {
}) })
t.Run("response headers", func(t *testing.T) { t.Run("response headers", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "/", nil) testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
testResponseModifier.Header().Set("Cache-Control", "no-cache") testResponseModifier.Header().Set("Cache-Control", "no-cache")
testResponseModifier.Header().Set("X-Rate-Limit", "100") testResponseModifier.Header().Set("X-Rate-Limit", "100")
testResponseModifier.WriteHeader(200) testResponseModifier.WriteHeader(http.StatusOK)
var out strings.Builder var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, err := ExpandVars(testResponseModifier, testRequest,
@@ -554,7 +554,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
}{ }{
{ {
name: "http scheme", name: "http scheme",
request: httptest.NewRequest("GET", "http://example.com/", nil), request: httptest.NewRequest(http.MethodGet, "http://example.com/", nil),
expected: "http", expected: "http",
}, },
{ {
@@ -581,7 +581,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
func TestExpandVars_UpstreamVariables(t *testing.T) { func TestExpandVars_UpstreamVariables(t *testing.T) {
// Upstream variables require context from routes package // Upstream variables require context from routes package
testRequest := httptest.NewRequest("GET", "/", nil) testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
@@ -607,7 +607,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
func TestExpandVars_NoHostPort(t *testing.T) { func TestExpandVars_NoHostPort(t *testing.T) {
// Test request without port in Host header // Test request without port in Host header
testRequest := httptest.NewRequest("GET", "/", nil) testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
testRequest.Host = "example.com" // No port testRequest.Host = "example.com" // No port
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
@@ -623,13 +623,13 @@ func TestExpandVars_NoHostPort(t *testing.T) {
var out strings.Builder var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out) err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", out.String()) require.Empty(t, out.String())
}) })
} }
func TestExpandVars_NoRemotePort(t *testing.T) { func TestExpandVars_NoRemotePort(t *testing.T) {
// Test request without port in RemoteAddr // Test request without port in RemoteAddr
testRequest := httptest.NewRequest("GET", "/", nil) testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
testRequest.RemoteAddr = "192.168.1.1" // No port testRequest.RemoteAddr = "192.168.1.1" // No port
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
@@ -638,19 +638,19 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
var out strings.Builder var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out) err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", out.String()) require.Empty(t, out.String())
}) })
t.Run("remote_port without port", func(t *testing.T) { t.Run("remote_port without port", func(t *testing.T) {
var out strings.Builder var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out) err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", out.String()) require.Empty(t, out.String())
}) })
} }
func TestExpandVars_WhitespaceHandling(t *testing.T) { func TestExpandVars_WhitespaceHandling(t *testing.T) {
testRequest := httptest.NewRequest("GET", "/test", nil) testRequest := httptest.NewRequest(http.MethodGet, "/test", nil)
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
var out strings.Builder var out strings.Builder

View File

@@ -27,6 +27,7 @@ type HTTPConfig struct {
// BuildTLSConfig creates a TLS configuration based on the HTTP config options. // BuildTLSConfig creates a TLS configuration based on the HTTP config options.
func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, error) { func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, error) {
//nolint:gosec
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
// Handle InsecureSkipVerify (legacy NoTLSVerify option) // Handle InsecureSkipVerify (legacy NoTLSVerify option)

View File

@@ -15,8 +15,10 @@ func NewSubstituteEnvReader(reader io.Reader) *SubstituteEnvReader {
return &SubstituteEnvReader{reader: reader} return &SubstituteEnvReader{reader: reader}
} }
const peekSize = 4096 const (
const maxVarNameLength = 256 peekSize = 4096
maxVarNameLength = 256
)
func (r *SubstituteEnvReader) Read(p []byte) (n int, err error) { func (r *SubstituteEnvReader) Read(p []byte) (n int, err error) {
// Return buffered data first // Return buffered data first
@@ -66,6 +68,7 @@ func (r *SubstituteEnvReader) Read(p []byte) (n int, err error) {
if nMore > 0 { if nMore > 0 {
incomplete = append(incomplete, more[:nMore]...) incomplete = append(incomplete, more[:nMore]...)
// Check if pattern is now complete // Check if pattern is now complete
//nolint:modernize
if idx := bytes.IndexByte(incomplete, '}'); idx >= 0 { if idx := bytes.IndexByte(incomplete, '}'); idx >= 0 {
// Pattern complete, append the rest back to chunk // Pattern complete, append the rest back to chunk
chunk = append(chunk, incomplete...) chunk = append(chunk, incomplete...)

View File

@@ -2,8 +2,8 @@ package serialization
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"os"
"strings" "strings"
"testing" "testing"
) )
@@ -11,17 +11,9 @@ import (
// setupEnv sets up environment variables for benchmarks // setupEnv sets up environment variables for benchmarks
func setupEnv(b *testing.B) { func setupEnv(b *testing.B) {
b.Helper() b.Helper()
os.Setenv("BENCH_VAR", "benchmark_value") b.Setenv("BENCH_VAR", "benchmark_value")
os.Setenv("BENCH_VAR_2", "second_value") b.Setenv("BENCH_VAR_2", "second_value")
os.Setenv("BENCH_VAR_3", "third_value") b.Setenv("BENCH_VAR_3", "third_value")
}
// cleanupEnv cleans up environment variables after benchmarks
func cleanupEnv(b *testing.B) {
b.Helper()
os.Unsetenv("BENCH_VAR")
os.Unsetenv("BENCH_VAR_2")
os.Unsetenv("BENCH_VAR_3")
} }
// BenchmarkSubstituteEnvReader_NoSubstitution benchmarks reading without any env substitutions // BenchmarkSubstituteEnvReader_NoSubstitution benchmarks reading without any env substitutions
@@ -44,7 +36,6 @@ data: some content here
// BenchmarkSubstituteEnvReader_SingleSubstitution benchmarks reading with a single env substitution // BenchmarkSubstituteEnvReader_SingleSubstitution benchmarks reading with a single env substitution
func BenchmarkSubstituteEnvReader_SingleSubstitution(b *testing.B) { func BenchmarkSubstituteEnvReader_SingleSubstitution(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
r := strings.NewReader(`key: ${BENCH_VAR} r := strings.NewReader(`key: ${BENCH_VAR}
`) `)
@@ -62,7 +53,6 @@ func BenchmarkSubstituteEnvReader_SingleSubstitution(b *testing.B) {
// BenchmarkSubstituteEnvReader_MultipleSubstitutions benchmarks reading with multiple env substitutions // BenchmarkSubstituteEnvReader_MultipleSubstitutions benchmarks reading with multiple env substitutions
func BenchmarkSubstituteEnvReader_MultipleSubstitutions(b *testing.B) { func BenchmarkSubstituteEnvReader_MultipleSubstitutions(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
r := strings.NewReader(`key1: ${BENCH_VAR} r := strings.NewReader(`key1: ${BENCH_VAR}
key2: ${BENCH_VAR_2} key2: ${BENCH_VAR_2}
@@ -96,7 +86,6 @@ func BenchmarkSubstituteEnvReader_LargeInput_NoSubstitution(b *testing.B) {
// BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions benchmarks large input with scattered substitutions // BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions benchmarks large input with scattered substitutions
func BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions(b *testing.B) { func BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
var builder bytes.Buffer var builder bytes.Buffer
for range 100 { for range 100 {
@@ -118,7 +107,6 @@ func BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions(b *testing.B) {
// BenchmarkSubstituteEnvReader_SmallBuffer benchmarks reading with a small buffer size // BenchmarkSubstituteEnvReader_SmallBuffer benchmarks reading with a small buffer size
func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) { func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
r := strings.NewReader(`key: ${BENCH_VAR} and some more content here`) r := strings.NewReader(`key: ${BENCH_VAR} and some more content here`)
buf := make([]byte, 16) buf := make([]byte, 16)
@@ -127,7 +115,7 @@ func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) {
reader := NewSubstituteEnvReader(r) reader := NewSubstituteEnvReader(r)
for { for {
_, err := reader.Read(buf) _, err := reader.Read(buf)
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
if err != nil { if err != nil {
@@ -141,7 +129,6 @@ func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) {
// BenchmarkSubstituteEnvReader_YAMLConfig benchmarks a realistic YAML config scenario // BenchmarkSubstituteEnvReader_YAMLConfig benchmarks a realistic YAML config scenario
func BenchmarkSubstituteEnvReader_YAMLConfig(b *testing.B) { func BenchmarkSubstituteEnvReader_YAMLConfig(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
r := strings.NewReader(`database: r := strings.NewReader(`database:
host: ${BENCH_VAR} host: ${BENCH_VAR}
@@ -170,7 +157,6 @@ server:
// BenchmarkSubstituteEnvReader_BoundaryPattern benchmarks patterns at buffer boundaries (4096 bytes) // BenchmarkSubstituteEnvReader_BoundaryPattern benchmarks patterns at buffer boundaries (4096 bytes)
func BenchmarkSubstituteEnvReader_BoundaryPattern(b *testing.B) { func BenchmarkSubstituteEnvReader_BoundaryPattern(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
// Pattern exactly at 4090 bytes, with ${VAR} crossing the 4096 boundary // Pattern exactly at 4090 bytes, with ${VAR} crossing the 4096 boundary
prefix := strings.Repeat("x", 4090) prefix := strings.Repeat("x", 4090)
@@ -189,7 +175,6 @@ func BenchmarkSubstituteEnvReader_BoundaryPattern(b *testing.B) {
// BenchmarkSubstituteEnvReader_MultipleBoundaries benchmarks multiple patterns crossing boundaries // BenchmarkSubstituteEnvReader_MultipleBoundaries benchmarks multiple patterns crossing boundaries
func BenchmarkSubstituteEnvReader_MultipleBoundaries(b *testing.B) { func BenchmarkSubstituteEnvReader_MultipleBoundaries(b *testing.B) {
setupEnv(b) setupEnv(b)
defer cleanupEnv(b)
var builder bytes.Buffer var builder bytes.Buffer
for range 10 { for range 10 {
@@ -210,8 +195,7 @@ func BenchmarkSubstituteEnvReader_MultipleBoundaries(b *testing.B) {
// BenchmarkSubstituteEnvReader_SpecialChars benchmarks substitution with special characters // BenchmarkSubstituteEnvReader_SpecialChars benchmarks substitution with special characters
func BenchmarkSubstituteEnvReader_SpecialChars(b *testing.B) { func BenchmarkSubstituteEnvReader_SpecialChars(b *testing.B) {
os.Setenv("SPECIAL_BENCH_VAR", `value with "quotes" and \backslash\`) b.Setenv("SPECIAL_BENCH_VAR", `value with "quotes" and \backslash\`)
defer os.Unsetenv("SPECIAL_BENCH_VAR")
r := strings.NewReader(`key: ${SPECIAL_BENCH_VAR} r := strings.NewReader(`key: ${SPECIAL_BENCH_VAR}
`) `)
@@ -228,8 +212,7 @@ func BenchmarkSubstituteEnvReader_SpecialChars(b *testing.B) {
// BenchmarkSubstituteEnvReader_EmptyValue benchmarks substitution with empty value // BenchmarkSubstituteEnvReader_EmptyValue benchmarks substitution with empty value
func BenchmarkSubstituteEnvReader_EmptyValue(b *testing.B) { func BenchmarkSubstituteEnvReader_EmptyValue(b *testing.B) {
os.Setenv("EMPTY_BENCH_VAR", "") b.Setenv("EMPTY_BENCH_VAR", "")
defer os.Unsetenv("EMPTY_BENCH_VAR")
r := strings.NewReader(`key: ${EMPTY_BENCH_VAR} r := strings.NewReader(`key: ${EMPTY_BENCH_VAR}
`) `)
@@ -246,8 +229,7 @@ func BenchmarkSubstituteEnvReader_EmptyValue(b *testing.B) {
// BenchmarkSubstituteEnvReader_DollarWithoutBrace benchmarks $ without following { // BenchmarkSubstituteEnvReader_DollarWithoutBrace benchmarks $ without following {
func BenchmarkSubstituteEnvReader_DollarWithoutBrace(b *testing.B) { func BenchmarkSubstituteEnvReader_DollarWithoutBrace(b *testing.B) {
os.Setenv("BENCH_VAR", "benchmark_value") b.Setenv("BENCH_VAR", "benchmark_value")
defer os.Unsetenv("BENCH_VAR")
r := strings.NewReader(`price: $100 and $200 for ${BENCH_VAR}`) r := strings.NewReader(`price: $100 and $200 for ${BENCH_VAR}`)

View File

@@ -2,8 +2,8 @@ package serialization
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"os"
"strings" "strings"
"testing" "testing"
@@ -11,8 +11,7 @@ import (
) )
func TestSubstituteEnvReader_Basic(t *testing.T) { func TestSubstituteEnvReader_Basic(t *testing.T) {
os.Setenv("TEST_VAR", "hello") t.Setenv("TEST_VAR", "hello")
defer os.Unsetenv("TEST_VAR")
input := []byte(`key: ${TEST_VAR}`) input := []byte(`key: ${TEST_VAR}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -23,10 +22,8 @@ func TestSubstituteEnvReader_Basic(t *testing.T) {
} }
func TestSubstituteEnvReader_Multiple(t *testing.T) { func TestSubstituteEnvReader_Multiple(t *testing.T) {
os.Setenv("VAR1", "first") t.Setenv("VAR1", "first")
os.Setenv("VAR2", "second") t.Setenv("VAR2", "second")
defer os.Unsetenv("VAR1")
defer os.Unsetenv("VAR2")
input := []byte(`a: ${VAR1}, b: ${VAR2}`) input := []byte(`a: ${VAR1}, b: ${VAR2}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -46,8 +43,6 @@ func TestSubstituteEnvReader_NoSubstitution(t *testing.T) {
} }
func TestSubstituteEnvReader_UnsetEnvError(t *testing.T) { func TestSubstituteEnvReader_UnsetEnvError(t *testing.T) {
os.Unsetenv("UNSET_VAR_FOR_TEST")
input := []byte(`key: ${UNSET_VAR_FOR_TEST}`) input := []byte(`key: ${UNSET_VAR_FOR_TEST}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -57,8 +52,7 @@ func TestSubstituteEnvReader_UnsetEnvError(t *testing.T) {
} }
func TestSubstituteEnvReader_SmallBuffer(t *testing.T) { func TestSubstituteEnvReader_SmallBuffer(t *testing.T) {
os.Setenv("SMALL_BUF_VAR", "value") t.Setenv("SMALL_BUF_VAR", "value")
defer os.Unsetenv("SMALL_BUF_VAR")
input := []byte(`key: ${SMALL_BUF_VAR}`) input := []byte(`key: ${SMALL_BUF_VAR}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -70,7 +64,7 @@ func TestSubstituteEnvReader_SmallBuffer(t *testing.T) {
if n > 0 { if n > 0 {
result = append(result, buf[:n]...) result = append(result, buf[:n]...)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
require.NoError(t, err) require.NoError(t, err)
@@ -79,8 +73,7 @@ func TestSubstituteEnvReader_SmallBuffer(t *testing.T) {
} }
func TestSubstituteEnvReader_SpecialChars(t *testing.T) { func TestSubstituteEnvReader_SpecialChars(t *testing.T) {
os.Setenv("SPECIAL_VAR", `hello "world" \n`) t.Setenv("SPECIAL_VAR", `hello "world" \n`)
defer os.Unsetenv("SPECIAL_VAR")
input := []byte(`key: ${SPECIAL_VAR}`) input := []byte(`key: ${SPECIAL_VAR}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -91,8 +84,7 @@ func TestSubstituteEnvReader_SpecialChars(t *testing.T) {
} }
func TestSubstituteEnvReader_EmptyValue(t *testing.T) { func TestSubstituteEnvReader_EmptyValue(t *testing.T) {
os.Setenv("EMPTY_VAR", "") t.Setenv("EMPTY_VAR", "")
defer os.Unsetenv("EMPTY_VAR")
input := []byte(`key: ${EMPTY_VAR}`) input := []byte(`key: ${EMPTY_VAR}`)
reader := NewSubstituteEnvReader(bytes.NewReader(input)) reader := NewSubstituteEnvReader(bytes.NewReader(input))
@@ -103,8 +95,7 @@ func TestSubstituteEnvReader_EmptyValue(t *testing.T) {
} }
func TestSubstituteEnvReader_LargeInput(t *testing.T) { func TestSubstituteEnvReader_LargeInput(t *testing.T) {
os.Setenv("LARGE_VAR", "replaced") t.Setenv("LARGE_VAR", "replaced")
defer os.Unsetenv("LARGE_VAR")
prefix := strings.Repeat("x", 5000) prefix := strings.Repeat("x", 5000)
suffix := strings.Repeat("y", 5000) suffix := strings.Repeat("y", 5000)
@@ -119,8 +110,7 @@ func TestSubstituteEnvReader_LargeInput(t *testing.T) {
} }
func TestSubstituteEnvReader_PatternAtBoundary(t *testing.T) { func TestSubstituteEnvReader_PatternAtBoundary(t *testing.T) {
os.Setenv("BOUNDARY_VAR", "boundary_value") t.Setenv("BOUNDARY_VAR", "boundary_value")
defer os.Unsetenv("BOUNDARY_VAR")
prefix := strings.Repeat("a", 4090) prefix := strings.Repeat("a", 4090)
input := []byte(prefix + "${BOUNDARY_VAR}") input := []byte(prefix + "${BOUNDARY_VAR}")
@@ -134,10 +124,8 @@ func TestSubstituteEnvReader_PatternAtBoundary(t *testing.T) {
} }
func TestSubstituteEnvReader_MultiplePatternsBoundary(t *testing.T) { func TestSubstituteEnvReader_MultiplePatternsBoundary(t *testing.T) {
os.Setenv("VAR_A", "aaa") t.Setenv("VAR_A", "aaa")
os.Setenv("VAR_B", "bbb") t.Setenv("VAR_B", "bbb")
defer os.Unsetenv("VAR_A")
defer os.Unsetenv("VAR_B")
prefix := strings.Repeat("x", 4090) prefix := strings.Repeat("x", 4090)
input := []byte(prefix + "${VAR_A} middle ${VAR_B}") input := []byte(prefix + "${VAR_A} middle ${VAR_B}")
@@ -151,12 +139,9 @@ func TestSubstituteEnvReader_MultiplePatternsBoundary(t *testing.T) {
} }
func TestSubstituteEnvReader_YAMLConfig(t *testing.T) { func TestSubstituteEnvReader_YAMLConfig(t *testing.T) {
os.Setenv("DB_HOST", "localhost") t.Setenv("DB_HOST", "localhost")
os.Setenv("DB_PORT", "5432") t.Setenv("DB_PORT", "5432")
os.Setenv("DB_PASSWORD", "secret123") t.Setenv("DB_PASSWORD", "secret123")
defer os.Unsetenv("DB_HOST")
defer os.Unsetenv("DB_PORT")
defer os.Unsetenv("DB_PASSWORD")
input := []byte(`database: input := []byte(`database:
host: ${DB_HOST} host: ${DB_HOST}

View File

@@ -87,7 +87,7 @@ func initPtr(dst reflect.Value) {
} }
} }
// Validate performs struct validation using go-playground/validator tags. // ValidateWithFieldTags performs struct validation using go-playground/validator tags.
// //
// It collects all validation errors and returns them as a single error. // It collects all validation errors and returns them as a single error.
// Field names in errors are prefixed with their namespace (e.g., "User.Email"). // Field names in errors are prefixed with their namespace (e.g., "User.Email").
@@ -521,7 +521,6 @@ func ConvertSlice(src reflect.Value, dst reflect.Value, checkValidateTag bool) e
// - Returns true if conversion was handled (even with error), false if // - Returns true if conversion was handled (even with error), false if
// conversion is unsupported. // conversion is unsupported.
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr error) { func ConvertString(src string, dst reflect.Value) (convertible bool, convErr error) {
convertible = true
dstT := dst.Type() dstT := dst.Type()
if dst.Kind() == reflect.Pointer { if dst.Kind() == reflect.Pointer {
if dst.IsNil() { if dst.IsNil() {

View File

@@ -17,7 +17,7 @@ func initConfigDirWatcher() {
configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath) configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath)
} }
// create a new file watcher for file under ConfigBasePath. // NewConfigFileWatcher creates a new file watcher for file under common.ConfigBasePath.
func NewConfigFileWatcher(filename string) Watcher { func NewConfigFileWatcher(filename string) Watcher {
configDirWatcherInitOnce.Do(initConfigDirWatcher) configDirWatcherInitOnce.Do(initConfigDirWatcher)
return configDirWatcher.Add(filename) return configDirWatcher.Add(filename)

View File

@@ -124,15 +124,13 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
retry := time.NewTicker(dockerWatcherRetryInterval) retry := time.NewTicker(dockerWatcherRetryInterval)
defer retry.Stop() defer retry.Stop()
ok := false
outer: outer:
for !ok { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-retry.C: case <-retry.C:
if checkConnection(ctx, client) { if checkConnection(ctx, client) {
ok = true
break outer break outer
} }
} }

View File

@@ -2,6 +2,7 @@ package events
import ( import (
"fmt" "fmt"
"maps"
dockerEvents "github.com/moby/moby/api/types/events" dockerEvents "github.com/moby/moby/api/types/events"
) )
@@ -69,9 +70,7 @@ var actionNameMap = func() (m map[Action]string) {
for k, v := range DockerEventMap { for k, v := range DockerEventMap {
m[v] = string(k) m[v] = string(k)
} }
for k, v := range fileActionNameMap { maps.Copy(m, fileActionNameMap)
m[k] = v
}
return m return m
}() }()