diff --git a/internal/api/handler.go b/internal/api/handler.go index e36eecab..791b0987 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -23,19 +23,21 @@ import ( apitypes "github.com/yusing/goutils/apitypes" ) +// NewHandler creates a new Gin engine for the API. +// // @title GoDoxy API // @version 1.0 // @description GoDoxy API // @termsOfService https://github.com/yusing/godoxy/blob/main/LICENSE - +// // @contact.name Yusing // @contact.url https://github.com/yusing/godoxy/issues - +// // @license.name MIT // @license.url https://github.com/yusing/godoxy/blob/main/LICENSE - +// // @BasePath /api/v1 - +// // @externalDocs.description GoDoxy Docs // @externalDocs.url https://docs.godoxy.dev func NewHandler(requireAuth bool) *gin.Engine { diff --git a/internal/api/v1/homepage/categories.go b/internal/api/v1/homepage/categories.go index 7ae61c77..b5e17eea 100644 --- a/internal/api/v1/homepage/categories.go +++ b/internal/api/v1/homepage/categories.go @@ -7,7 +7,6 @@ import ( entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" "github.com/yusing/godoxy/internal/homepage" - _ "github.com/yusing/goutils/apitypes" apitypes "github.com/yusing/goutils/apitypes" ) diff --git a/internal/api/v1/route/playground.go b/internal/api/v1/route/playground.go index e4f464ed..39d40549 100644 --- a/internal/api/v1/route/playground.go +++ b/internal/api/v1/route/playground.go @@ -257,8 +257,8 @@ func handlerWithRecover(w http.ResponseWriter, r *http.Request, h http.HandlerFu } func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) { - var parsedRules []ParsedRule - var rulesList rules.Rules + parsedRules := make([]ParsedRule, 0, len(rawRules)) + rulesList := make(rules.Rules, 0, len(rawRules)) var valErrs gperr.Builder diff --git a/internal/api/v1/route/playground_test.go b/internal/api/v1/route/playground_test.go index 7b357856..91c79026 100644 --- a/internal/api/v1/route/playground_test.go +++ b/internal/api/v1/route/playground_test.go @@ -79,7 +79,7 @@ func TestPlayground(t *testing.T) { if len(resp.MatchedRules) != 1 { t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules)) } - if resp.FinalResponse.StatusCode != 403 { + if resp.FinalResponse.StatusCode != http.StatusForbidden { t.Errorf("expected status 403, got %d", resp.FinalResponse.StatusCode) } if resp.UpstreamCalled { @@ -168,7 +168,7 @@ func TestPlayground(t *testing.T) { if len(resp.MatchedRules) != 1 { t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules)) } - if resp.FinalResponse.StatusCode != 405 { + if resp.FinalResponse.StatusCode != http.StatusMethodNotAllowed { t.Errorf("expected status 405, got %d", resp.FinalResponse.StatusCode) } }, @@ -179,7 +179,7 @@ func TestPlayground(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create request body, _ := json.Marshal(tt.request) - req := httptest.NewRequest("POST", "/api/v1/route/playground", bytes.NewReader(body)) + req := httptest.NewRequest(http.MethodPost, "/api/v1/route/playground", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") // Create response recorder @@ -214,7 +214,7 @@ func TestPlayground(t *testing.T) { func TestPlaygroundInvalidRequest(t *testing.T) { gin.SetMode(gin.TestMode) - req := httptest.NewRequest("POST", "/api/v1/route/playground", bytes.NewReader([]byte(`{}`))) + req := httptest.NewRequest(http.MethodPost, "/api/v1/route/playground", bytes.NewReader([]byte(`{}`))) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index 0fa529cd..1bd894e8 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -135,7 +135,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) { claims = &sessionClaims{} - sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (interface{}, error) { + sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } diff --git a/internal/auth/userpass.go b/internal/auth/userpass.go index 73bdb199..e9fc3343 100644 --- a/internal/auth/userpass.go +++ b/internal/auth/userpass.go @@ -24,8 +24,9 @@ type ( tokenTTL time.Duration } UserPassClaims struct { - Username string `json:"username"` jwt.RegisteredClaims + + Username string `json:"username"` } ) @@ -78,7 +79,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error { return ErrMissingSessionToken } var claims UserPassClaims - token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 336d9a62..33e88b37 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -467,7 +467,7 @@ func (p *Provider) scheduleRenewal(parent task.Parent) { log.Warn().Err(p.fmtError(err)).Msg("autocert: cert renew failed") notif.Notify(¬if.LogMessage{ Level: zerolog.ErrorLevel, - Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()), + Title: "SSL certificate renewal failed for " + p.GetName(), Body: notif.MessageBody(err.Error()), }) return @@ -477,7 +477,7 @@ func (p *Provider) scheduleRenewal(parent task.Parent) { notif.Notify(¬if.LogMessage{ Level: zerolog.InfoLevel, - Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()), + Title: "SSL certificate renewed for " + p.GetName(), Body: notif.ListBody(p.cfg.Domains), }) diff --git a/internal/autocert/types/context.go b/internal/autocert/types/context.go index 6b4167bc..37e17f5f 100644 --- a/internal/autocert/types/context.go +++ b/internal/autocert/types/context.go @@ -4,7 +4,7 @@ import "context" type ContextKey struct{} -func SetCtx(ctx interface{ SetValue(any, any) }, p Provider) { +func SetCtx(ctx interface{ SetValue(key, value any) }, p Provider) { ctx.SetValue(ContextKey{}, p) } diff --git a/internal/config/events.go b/internal/config/events.go index 4adf5f27..a98def0f 100644 --- a/internal/config/events.go +++ b/internal/config/events.go @@ -26,11 +26,9 @@ var ( const configEventFlushInterval = 500 * time.Millisecond -const ( - cfgRenameWarn = `Config file renamed, not reloading. -Make sure you rename it back before next time you start.` - cfgDeleteWarn = `Config file deleted, not reloading. -You may run "ls-config" to show or dump the current config.` +var ( + errCfgRenameWarn = errors.New("config file renamed, not reloading; Make sure you rename it back before next time you start") + errCfgDeleteWarn = errors.New(`config file deleted, not reloading; You may run "ls-config" to show or dump the current config`) ) func logNotifyError(action string, err error) { @@ -142,11 +140,11 @@ func OnConfigChange(ev []watcherEvents.Event) { // no matter how many events during the interval // just reload once and check the last event switch ev[len(ev)-1].Action { - case events.ActionFileRenamed: - logNotifyWarn("rename", errors.New(cfgRenameWarn)) + case watcherEvents.ActionFileRenamed: + logNotifyWarn("rename", errCfgRenameWarn) return - case events.ActionFileDeleted: - logNotifyWarn("delete", errors.New(cfgDeleteWarn)) + case watcherEvents.ActionFileDeleted: + logNotifyWarn("delete", errCfgDeleteWarn) return } diff --git a/internal/config/state.go b/internal/config/state.go index cc19ff1d..22c7f2d7 100644 --- a/internal/config/state.go +++ b/internal/config/state.go @@ -17,6 +17,7 @@ import ( "github.com/goccy/go-yaml" "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" acl "github.com/yusing/godoxy/internal/acl/types" "github.com/yusing/godoxy/internal/agentpool" "github.com/yusing/godoxy/internal/api" @@ -90,11 +91,6 @@ func SetState(state config.State) { cfg := state.Value() config.ActiveState.Store(state) homepage.ActiveConfig.Store(&cfg.Homepage) - if autocertProvider := state.AutoCertProvider(); autocertProvider != nil { - autocert.ActiveProvider.Store(autocertProvider.(*autocert.Provider)) - } else { - autocert.ActiveProvider.Store(nil) - } } func HasState() bool { @@ -203,25 +199,31 @@ func (state *state) NumProviders() int { } func (state *state) FlushTmpLog() { - state.tmpLogBuf.WriteTo(os.Stdout) + _, _ = state.tmpLogBuf.WriteTo(os.Stdout) state.tmpLogBuf.Reset() } func (state *state) StartAPIServers() { // API Handler needs to start after auth is initialized. - server.StartServer(state.task.Subtask("api_server", false), server.Options{ + _, err := server.StartServer(state.task.Subtask("api_server", false), server.Options{ Name: "api", HTTPAddr: common.APIHTTPAddr, Handler: api.NewHandler(true), }) + if err != nil { + log.Err(err).Msg("failed to start API server") + } // Local API Handler is used for unauthenticated access. if common.LocalAPIHTTPAddr != "" { - server.StartServer(state.task.Subtask("local_api_server", false), server.Options{ + _, err := server.StartServer(state.task.Subtask("local_api_server", false), server.Options{ Name: "local_api", HTTPAddr: common.LocalAPIHTTPAddr, Handler: api.NewHandler(false), }) + if err != nil { + log.Err(err).Msg("failed to start local API server") + } } } diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 29be8827..15b15dbb 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -47,10 +47,10 @@ var _ entrypoint.Entrypoint = &Entrypoint{} var emptyCfg Config -func NewTestEntrypoint(t testing.TB, cfg *Config) *Entrypoint { - t.Helper() +func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint { + tb.Helper() - testTask := task.GetTestTask(t) + testTask := task.GetTestTask(tb) ep := NewEntrypoint(testTask, cfg) entrypoint.SetCtx(testTask, ep) return ep @@ -160,6 +160,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request } func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute { + //nolint:modernize idx := strings.IndexByte(host, '.') if idx != -1 { target := host[:idx] diff --git a/internal/entrypoint/http_server.go b/internal/entrypoint/http_server.go index 8d482f5b..01475a1f 100644 --- a/internal/entrypoint/http_server.go +++ b/internal/entrypoint/http_server.go @@ -19,7 +19,7 @@ import ( "github.com/yusing/goutils/server" ) -// httpServer is a server that listens on a given address and serves HTTP routes. +// HTTPServer is a server that listens on a given address and serves HTTP routes. type HTTPServer interface { Listen(addr string, proto HTTPProto) error AddRoute(route types.HTTPRoute) @@ -109,6 +109,8 @@ func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { rec := accesslog.GetResponseRecorder(w) w = rec defer func() { + // there is no body to close + //nolint:bodyclose srv.ep.accessLogger.LogRequest(r, rec.Response()) accesslog.PutResponseRecorder(rec) }() diff --git a/internal/entrypoint/shortlink_test.go b/internal/entrypoint/shortlink_test.go index 3c2ce3e5..99322340 100644 --- a/internal/entrypoint/shortlink_test.go +++ b/internal/entrypoint/shortlink_test.go @@ -19,7 +19,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) { matcher.AddRoute("app.domain.com") t.Run("exact path", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -28,7 +28,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) { }) t.Run("with path remainder", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app/foo/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/app/foo/bar", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -37,7 +37,7 @@ func TestShortLinkMatcher_FQDNAlias(t *testing.T) { }) t.Run("with query", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app/foo?x=y&z=1", nil) + req := httptest.NewRequest(http.MethodGet, "/app/foo?x=y&z=1", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -53,7 +53,7 @@ func TestShortLinkMatcher_SubdomainAlias(t *testing.T) { matcher.AddRoute("app") t.Run("exact path", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -62,7 +62,7 @@ func TestShortLinkMatcher_SubdomainAlias(t *testing.T) { }) t.Run("with path remainder", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app/foo/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/app/foo/bar", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -78,7 +78,7 @@ func TestShortLinkMatcher_NotFound(t *testing.T) { matcher.AddRoute("app") t.Run("missing key", func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -86,7 +86,7 @@ func TestShortLinkMatcher_NotFound(t *testing.T) { }) t.Run("unknown key", func(t *testing.T) { - req := httptest.NewRequest("GET", "/unknown", nil) + req := httptest.NewRequest(http.MethodGet, "/unknown", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -103,13 +103,13 @@ func TestShortLinkMatcher_AddDelRoute(t *testing.T) { matcher.AddRoute("app2.domain.com") t.Run("both routes work", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app1", nil) + req := httptest.NewRequest(http.MethodGet, "/app1", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location")) - req = httptest.NewRequest("GET", "/app2.domain.com", nil) + req = httptest.NewRequest(http.MethodGet, "/app2.domain.com", nil) w = httptest.NewRecorder() matcher.ServeHTTP(w, req) assert.Equal(t, http.StatusTemporaryRedirect, w.Code) @@ -119,12 +119,12 @@ func TestShortLinkMatcher_AddDelRoute(t *testing.T) { t.Run("delete route", func(t *testing.T) { matcher.DelRoute("app1") - req := httptest.NewRequest("GET", "/app1", nil) + req := httptest.NewRequest(http.MethodGet, "/app1", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) - req = httptest.NewRequest("GET", "/app2.domain.com", nil) + req = httptest.NewRequest(http.MethodGet, "/app2.domain.com", nil) w = httptest.NewRecorder() matcher.ServeHTTP(w, req) assert.Equal(t, http.StatusTemporaryRedirect, w.Code) @@ -140,7 +140,7 @@ func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) { t.Run("subdomain alias ignored", func(t *testing.T) { matcher.AddRoute("app") - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -150,7 +150,7 @@ func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) { t.Run("FQDN alias still works", func(t *testing.T) { matcher.AddRoute("app.domain.com") - req := httptest.NewRequest("GET", "/app.domain.com", nil) + req := httptest.NewRequest(http.MethodGet, "/app.domain.com", nil) w := httptest.NewRecorder() matcher.ServeHTTP(w, req) @@ -169,7 +169,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) { require.NoError(t, err) t.Run("shortlink host", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) req.Host = common.ShortLinkPrefix w := httptest.NewRecorder() server.ServeHTTP(w, req) @@ -179,7 +179,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) { }) t.Run("shortlink host with port", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) req.Host = common.ShortLinkPrefix + ":8080" w := httptest.NewRecorder() server.ServeHTTP(w, req) @@ -189,7 +189,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) { }) t.Run("normal host", func(t *testing.T) { - req := httptest.NewRequest("GET", "/app", nil) + req := httptest.NewRequest(http.MethodGet, "/app", nil) req.Host = "app.example.com" w := httptest.NewRecorder() server.ServeHTTP(w, req) diff --git a/internal/health/check/docker.go b/internal/health/check/docker.go index 5d1d4cff..f595cc32 100644 --- a/internal/health/check/docker.go +++ b/internal/health/check/docker.go @@ -16,21 +16,23 @@ import ( type DockerHealthcheckState struct { client *docker.SharedClient - containerId string + containerID string numDockerFailures int } const dockerFailuresThreshold = 3 -var ErrDockerHealthCheckFailedTooManyTimes = errors.New("docker health check failed too many times") -var ErrDockerHealthCheckNotAvailable = errors.New("docker health check not available") +var ( + ErrDockerHealthCheckFailedTooManyTimes = errors.New("docker health check failed too many times") + ErrDockerHealthCheckNotAvailable = errors.New("docker health check not available") +) -func NewDockerHealthcheckState(client *docker.SharedClient, containerId string) *DockerHealthcheckState { +func NewDockerHealthcheckState(client *docker.SharedClient, containerID string) *DockerHealthcheckState { client.InterceptHTTPClient(interceptDockerInspectResponse) return &DockerHealthcheckState{ client: client, - containerId: containerId, + containerID: containerID, numDockerFailures: 0, } } @@ -44,7 +46,7 @@ func Docker(ctx context.Context, state *DockerHealthcheckState, timeout time.Dur defer cancel() // the actual inspect response is intercepted and returned as RequestInterceptedError - _, err := state.client.ContainerInspect(ctx, state.containerId, client.ContainerInspectOptions{}) + _, err := state.client.ContainerInspect(ctx, state.containerID, client.ContainerInspectOptions{}) var interceptedErr *httputils.RequestInterceptedError if !httputils.AsRequestInterceptedError(err, &interceptedErr) { diff --git a/internal/health/monitor/monitor.go b/internal/health/monitor/monitor.go index 68a6de8d..3effb0d8 100644 --- a/internal/health/monitor/monitor.go +++ b/internal/health/monitor/monitor.go @@ -14,6 +14,7 @@ import ( config "github.com/yusing/godoxy/internal/config/types" "github.com/yusing/godoxy/internal/notif" "github.com/yusing/godoxy/internal/types" + "github.com/yusing/goutils/events" strutils "github.com/yusing/goutils/strings" "github.com/yusing/goutils/synk" "github.com/yusing/goutils/task" @@ -269,6 +270,7 @@ func (mon *monitor) notifyServiceUp(logger *zerolog.Logger, result *types.Health Body: extras, Color: notif.ColorSuccess, }) + events.Global.Add(events.NewEvent(events.LevelInfo, "health", "service_up", mon)) } func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.HealthCheckResult) { @@ -281,6 +283,7 @@ func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.Heal Body: extras, Color: notif.ColorError, }) + events.Global.Add(events.NewEvent(events.LevelWarn, "health", "service_down", mon)) } func (mon *monitor) buildNotificationExtras(result *types.HealthCheckResult) notif.FieldsBody { diff --git a/internal/health/monitor/new.go b/internal/health/monitor/new.go index 285f081b..87121c77 100644 --- a/internal/health/monitor/new.go +++ b/internal/health/monitor/new.go @@ -2,9 +2,9 @@ package monitor import ( "errors" - "fmt" "net/http" "net/url" + "strconv" "time" "github.com/rs/zerolog/log" @@ -14,8 +14,10 @@ import ( "github.com/yusing/godoxy/internal/types" ) -type Result = types.HealthCheckResult -type Monitor = types.HealthMonCheck +type ( + Result = types.HealthCheckResult + Monitor = types.HealthMonCheck +) // NewMonitor creates a health monitor based on the route type and configuration. // @@ -78,22 +80,22 @@ func NewFileServerHealthMonitor(config types.HealthCheckConfig, path string) Mon return &mon } -func NewStreamHealthMonitor(config types.HealthCheckConfig, targetUrl *url.URL) Monitor { +func NewStreamHealthMonitor(config types.HealthCheckConfig, targetURL *url.URL) Monitor { var mon monitor - mon.init(targetUrl, config, func(u *url.URL) (result Result, err error) { + mon.init(targetURL, config, func(u *url.URL) (result Result, err error) { return healthcheck.Stream(mon.Context(), u, config.Timeout) }) return &mon } -func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.SharedClient, containerId string, fallback Monitor) Monitor { - state := healthcheck.NewDockerHealthcheckState(client, containerId) +func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.SharedClient, containerID string, fallback Monitor) Monitor { + state := healthcheck.NewDockerHealthcheckState(client, containerID) displayURL := &url.URL{ // only for display purposes, no actual request is made Scheme: "docker", Host: client.DaemonHost(), - Path: "/containers/" + containerId + "/json", + Path: "/containers/" + containerID + "/json", } - logger := log.With().Str("host", client.DaemonHost()).Str("container_id", containerId).Logger() + logger := log.With().Str("host", client.DaemonHost()).Str("container_id", containerID).Logger() isFirstFailure := true var mon monitor @@ -114,20 +116,20 @@ func NewDockerHealthMonitor(config types.HealthCheckConfig, client *docker.Share return &mon } -func NewAgentProxiedMonitor(config types.HealthCheckConfig, agent *agentpool.Agent, targetUrl *url.URL) Monitor { +func NewAgentProxiedMonitor(config types.HealthCheckConfig, agent *agentpool.Agent, targetURL *url.URL) Monitor { var mon monitor - mon.init(targetUrl, config, func(u *url.URL) (result Result, err error) { + mon.init(targetURL, config, func(u *url.URL) (result Result, err error) { return CheckHealthAgentProxied(agent, config.Timeout, u) }) return &mon } -func CheckHealthAgentProxied(agent *agentpool.Agent, timeout time.Duration, targetUrl *url.URL) (Result, error) { +func CheckHealthAgentProxied(agent *agentpool.Agent, timeout time.Duration, targetURL *url.URL) (Result, error) { query := url.Values{ - "scheme": {targetUrl.Scheme}, - "host": {targetUrl.Host}, - "path": {targetUrl.Path}, - "timeout": {fmt.Sprintf("%d", timeout.Milliseconds())}, + "scheme": {targetURL.Scheme}, + "host": {targetURL.Host}, + "path": {targetURL.Path}, + "timeout": {strconv.FormatInt(timeout.Milliseconds(), 10)}, } resp, err := agent.DoHealthCheck(timeout, query.Encode()) result := Result{ diff --git a/internal/homepage/icons/fetch/fetch.go b/internal/homepage/icons/fetch/fetch.go index 40061445..67b4b4ea 100644 --- a/internal/homepage/icons/fetch/fetch.go +++ b/internal/homepage/icons/fetch/fetch.go @@ -137,11 +137,11 @@ func fetchIcon(ctx context.Context, filename string) (Result, error) { for _, fileType := range []string{"svg", "webp", "png"} { result, err := fetchKnownIcon(ctx, icons.NewURL(icons.SourceSelfhSt, filename, fileType)) if err == nil { - return result, err + return result, nil } result, err = fetchKnownIcon(ctx, icons.NewURL(icons.SourceWalkXCode, filename, fileType)) if err == nil { - return result, err + return result, nil } } return FetchResultWithErrorf(http.StatusNotFound, "no icon found") @@ -152,6 +152,8 @@ type contextValue struct { uri string } +type contextKey struct{} + func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) (Result, error) { for _, ref := range r.References() { ref = sanitizeName(ref) @@ -160,7 +162,7 @@ func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) ( } result, err := fetchIcon(ctx, ref) if err == nil { - return result, err + return result, nil } } if r, ok := r.(httpRoute); ok { @@ -168,13 +170,13 @@ func FindIcon(ctx context.Context, r route, uri string, variant icons.Variant) ( return FetchResultWithErrorf(http.StatusServiceUnavailable, "service unavailable") } // fallback to parse html - return findIconSlowCached(context.WithValue(ctx, "route", contextValue{r: r, uri: uri}), r.Key()) + return findIconSlowCached(context.WithValue(ctx, contextKey{}, contextValue{r: r, uri: uri}), r.Key()) } return FetchResultWithErrorf(http.StatusNotFound, "no icon found") } var findIconSlowCached = cache.NewKeyFunc(func(ctx context.Context, key string) (Result, error) { - v := ctx.Value("route").(contextValue) + v := ctx.Value(contextKey{}).(contextValue) return findIconSlow(ctx, v.r, v.uri, nil) }).WithMaxEntries(200).WithRetriesConstantBackoff(math.MaxInt, 15*time.Second).Build() // infinite retries, 15 seconds interval diff --git a/internal/idlewatcher/handle_stream.go b/internal/idlewatcher/handle_stream.go index ec4f5df5..3dfde66c 100644 --- a/internal/idlewatcher/handle_stream.go +++ b/internal/idlewatcher/handle_stream.go @@ -11,7 +11,7 @@ var _ nettypes.Stream = (*Watcher)(nil) // ListenAndServe implements nettypes.Stream. func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) error { - return w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck + return w.stream.ListenAndServe(ctx, func(ctx context.Context) error { return w.preDial(ctx, predial) }, func(ctx context.Context) error { return w.onRead(ctx, onRead) diff --git a/internal/net/gphttp/middleware/bypass_test.go b/internal/net/gphttp/middleware/bypass_test.go index 5c2afb2b..631c55ff 100644 --- a/internal/net/gphttp/middleware/bypass_test.go +++ b/internal/net/gphttp/middleware/bypass_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/entrypoint" . "github.com/yusing/godoxy/internal/net/gphttp/middleware" "github.com/yusing/godoxy/internal/route" @@ -40,7 +39,7 @@ func TestBypassCIDR(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req.RemoteAddr = test.remoteAddr recorder := httptest.NewRecorder() mr.ModifyRequest(noOpHandler, recorder, req) @@ -76,7 +75,7 @@ func TestBypassPath(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil) recorder := httptest.NewRecorder() mr.ModifyRequest(noOpHandler, recorder, req) expect.NoError(t, err) @@ -126,7 +125,7 @@ func TestReverseProxyBypass(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil) recorder := httptest.NewRecorder() rp.ServeHTTP(recorder, req) if test.expectBypass { @@ -160,7 +159,7 @@ func TestBypassResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("test")), @@ -201,7 +200,7 @@ func TestBypassResponse(t *testing.T) { StatusCode: test.statusCode, Body: io.NopCloser(strings.NewReader("test")), Header: make(http.Header), - Request: httptest.NewRequest("GET", "http://example.com", nil), + Request: httptest.NewRequest(http.MethodGet, "http://example.com", nil), } mErr := mr.ModifyResponse(resp) expect.NoError(t, mErr) @@ -232,10 +231,12 @@ func TestEntrypointBypassRoute(t *testing.T) { entry := entrypoint.NewTestEntrypoint(t, nil) _, err = route.NewStartedTestRoute(t, &route.Route{ - Alias: "test-route", - Host: host, + Alias: "test-route", + Scheme: routeTypes.SchemeHTTP, + Host: host, Port: routeTypes.Port{ - Proxy: portInt, + Listening: 1000, + Proxy: portInt, }, }) expect.NoError(t, err) @@ -255,8 +256,8 @@ func TestEntrypointBypassRoute(t *testing.T) { expect.NoError(t, err) recorder := httptest.NewRecorder() - req := httptest.NewRequest("GET", "http://test-route.example.com", nil) - server, ok := entry.GetServer(common.ProxyHTTPAddr) + req := httptest.NewRequest(http.MethodGet, "http://test-route.example.com", nil) + server, ok := entry.GetServer(":1000") if !ok { t.Fatal("server not found") } diff --git a/internal/net/gphttp/middleware/cloudflare_real_ip.go b/internal/net/gphttp/middleware/cloudflare_real_ip.go index 39bacd68..8733dc64 100644 --- a/internal/net/gphttp/middleware/cloudflare_real_ip.go +++ b/internal/net/gphttp/middleware/cloudflare_real_ip.go @@ -105,7 +105,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*nettypes.CIDR) error { return err } - resp, err := http.DefaultClient.Do(req) //nolint:gosec + resp, err := http.DefaultClient.Do(req) if err != nil { return err } diff --git a/internal/net/gphttp/middleware/crowdsec.go b/internal/net/gphttp/middleware/crowdsec.go index 362ef6e1..9540deb6 100644 --- a/internal/net/gphttp/middleware/crowdsec.go +++ b/internal/net/gphttp/middleware/crowdsec.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -48,7 +49,7 @@ func (m *crowdsecMiddleware) setup() { func (m *crowdsecMiddleware) finalize() error { if !strings.HasPrefix(m.Endpoint, "/") { - return fmt.Errorf("endpoint must start with /") + return errors.New("endpoint must start with /") } if m.Timeout == 0 { m.Timeout = 5 * time.Second @@ -179,12 +180,12 @@ func (m *crowdsecMiddleware) buildCrowdSecURL(ctx context.Context) (string, erro // If not found in routes, assume it's an IP address if m.Port == 0 { - return "", fmt.Errorf("port must be specified when using IP address") + return "", errors.New("port must be specified when using IP address") } return fmt.Sprintf("http://%s%s", net.JoinHostPort(m.Route, strconv.Itoa(m.Port)), m.Endpoint), nil } - return "", fmt.Errorf("route or IP address must be specified") + return "", errors.New("route or IP address must be specified") } func (m *crowdsecMiddleware) getHTTPVersion(r *http.Request) string { diff --git a/internal/notif/gotify.go b/internal/notif/gotify.go index 15b4a008..8fe4fcce 100644 --- a/internal/notif/gotify.go +++ b/internal/notif/gotify.go @@ -59,7 +59,7 @@ func (client *GotifyClient) MarshalMessage(logMsg *LogMessage) ([]byte, error) { } if client.Format == LogFormatMarkdown { - msg.Extras = map[string]interface{}{ + msg.Extras = map[string]any{ "client::display": map[string]string{ "contentType": "text/markdown", }, diff --git a/internal/proxmox/client.go b/internal/proxmox/client.go index be4cbce4..1a101958 100644 --- a/internal/proxmox/client.go +++ b/internal/proxmox/client.go @@ -20,6 +20,7 @@ import ( type Client struct { *proxmox.Client *proxmox.Cluster + Version *proxmox.Version BaseURL *url.URL // id -> resource; id: lxc/ or qemu/ @@ -29,6 +30,7 @@ type Client struct { type VMResource struct { *proxmox.ClusterResource + IPs []net.IP } @@ -37,9 +39,9 @@ var ( ErrNoResources = errors.New("no resources") ) -func NewClient(baseUrl string, opts ...proxmox.Option) *Client { +func NewClient(baseURL string, opts ...proxmox.Option) *Client { return &Client{ - Client: proxmox.NewClient(baseUrl, opts...), + Client: proxmox.NewClient(baseURL, opts...), resources: make(map[string]*VMResource), } } diff --git a/internal/proxmox/lxc.go b/internal/proxmox/lxc.go index d5352e60..66a0c4ee 100644 --- a/internal/proxmox/lxc.go +++ b/internal/proxmox/lxc.go @@ -109,7 +109,7 @@ func (n *Node) LXCIsStopped(ctx context.Context, vmid uint64) (bool, error) { } func (n *Node) LXCSetShutdownTimeout(ctx context.Context, vmid uint64, timeout time.Duration) error { - return n.client.Put(ctx, fmt.Sprintf("/nodes/%s/lxc/%d/config", n.name, vmid), map[string]interface{}{ + return n.client.Put(ctx, fmt.Sprintf("/nodes/%s/lxc/%d/config", n.name, vmid), map[string]any{ "startup": fmt.Sprintf("down=%.0f", timeout.Seconds()), }, nil) } diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 67729bd9..2389c672 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -91,7 +91,7 @@ func (p *Provider) GetType() provider.Type { return p.t } -// to work with json marshaller. +// MarshalText implements encoding.TextMarshaler. func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } diff --git a/internal/route/route.go b/internal/route/route.go index 9408745b..1c233644 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -57,7 +57,7 @@ type ( PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"` RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` - HealthCheck types.HealthCheckConfig `json:"healthcheck,omitempty" extensions:"x-nullable"` // null on load-balancer routes + HealthCheck types.HealthCheckConfig `json:"healthcheck,omitzero" extensions:"x-nullable"` // null on load-balancer routes LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"` Homepage *homepage.ItemConfig `json:"homepage"` @@ -276,10 +276,10 @@ func (r *Route) validate() error { case route.SchemeFileServer: r.Host = "" r.Port.Proxy = 0 - r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("https://%s", net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))) + r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening))) r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root) case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C: - r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("https://%s", net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))) + r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening))) r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s", r.Scheme, net.JoinHostPort(r.Host, strconv.Itoa(r.Port.Proxy)))) case route.SchemeTCP, route.SchemeUDP: bindIP := net.ParseIP(r.Bind) @@ -588,10 +588,9 @@ func (r *Route) References() []string { return []string{r.Proxmox.VMName, aliasRef, r.Proxmox.Services[0]} } return []string{r.Proxmox.Services[0], aliasRef} - } else { - if r.Proxmox.VMName != aliasRef { - return []string{r.Proxmox.VMName, aliasRef} - } + } + if r.Proxmox.VMName != aliasRef { + return []string{r.Proxmox.VMName, aliasRef} } } return []string{aliasRef} diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index f1ed1cbd..11e98412 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -76,6 +76,7 @@ var commands = map[string]struct { if len(args) != 0 { return nil, ErrExpectNoArg } + //nolint:nilnil return nil, nil }, build: func(args any) CommandHandler { @@ -329,7 +330,7 @@ var commands = map[string]struct { helpExample(CommandSet, "header", "User-Agent", "godoxy"), ), args: map[string]string{ - "target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")), + "target": "the target to set, can be " + strings.Join(AllFields, ", "), "field": "the field to set", "value": "the value to set", }, @@ -349,7 +350,7 @@ var commands = map[string]struct { helpExample(CommandAdd, "header", "X-Foo", "bar"), ), args: map[string]string{ - "target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")), + "target": "the target to add, can be " + strings.Join(AllFields, ", "), "field": "the field to add", "value": "the value to add", }, @@ -369,7 +370,7 @@ var commands = map[string]struct { helpExample(CommandRemove, "header", "User-Agent"), ), args: map[string]string{ - "target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")), + "target": "the target to remove, can be " + strings.Join(AllFields, ", "), "field": "the field to remove", }, }, @@ -511,8 +512,10 @@ var commands = map[string]struct { }, } -type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString] -type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString] +type ( + onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString] + onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString] +) // Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index 352ea752..79b39f87 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -53,7 +53,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("POST", "/api/users", nil) + req := httptest.NewRequest(http.MethodPost, "/api/users", nil) req.Header.Set("User-Agent", "test-agent") w := httptest.NewRecorder() @@ -70,7 +70,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) { } func TestLogCommand_StdoutAndStderr(t *testing.T) { - upstream := mockUpstream(200, "success") + upstream := mockUpstream(http.StatusOK, "success") var rules Rules err := parseRules(` @@ -85,7 +85,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) @@ -96,7 +96,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { } func TestLogCommand_DifferentLogLevels(t *testing.T) { - upstream := mockUpstream(404, "not found") + upstream := mockUpstream(http.StatusNotFound, "not found") infoFile := TestRandomFileName() warnFile := TestRandomFileName() @@ -140,7 +140,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Custom-Header", "custom-value") w.Header().Set("Content-Length", "42") - w.WriteHeader(201) + w.WriteHeader(http.StatusCreated) w.Write([]byte("created")) }) @@ -176,13 +176,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/error": - w.WriteHeader(500) + w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("internal server error")) case "/notfound": - w.WriteHeader(404) + w.WriteHeader(http.StatusNotFound) w.Write([]byte("not found")) default: - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("success")) } }) @@ -206,22 +206,22 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { handler := rules.BuildHandler(upstream) // Test success request - req1 := httptest.NewRequest("GET", "/success", nil) + req1 := httptest.NewRequest(http.MethodGet, "/success", nil) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) // Test not found request - req2 := httptest.NewRequest("GET", "/notfound", nil) + req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 404, w2.Code) + assert.Equal(t, http.StatusNotFound, w2.Code) // Test server error request - req3 := httptest.NewRequest("POST", "/error", nil) + req3 := httptest.NewRequest(http.MethodPost, "/error", nil) w3 := httptest.NewRecorder() handler.ServeHTTP(w3, req3) - assert.Equal(t, 500, w3.Code) + assert.Equal(t, http.StatusInternalServerError, w3.Code) // Verify success log successContent := TestFileContent(successFile) @@ -238,7 +238,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { } func TestLogCommand_MultipleLogEntries(t *testing.T) { - upstream := mockUpstream(200, "response") + upstream := mockUpstream(http.StatusOK, "response") tempFile := TestRandomFileName() @@ -266,7 +266,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) } // Verify all requests were logged diff --git a/internal/route/rules/do_set_test.go b/internal/route/rules/do_set_test.go index 4d77ccd1..21b20108 100644 --- a/internal/route/rules/do_set_test.go +++ b/internal/route/rules/do_set_test.go @@ -67,7 +67,7 @@ func TestFieldHandler_Header(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) w := httptest.NewRecorder() @@ -126,8 +126,8 @@ func TestFieldHandler_ResponseHeader(t *testing.T) { verify: func(w *httptest.ResponseRecorder) { values := w.Header()["X-Response-Test"] require.Len(t, values, 2) - assert.Equal(t, values[0], "existing-value") - assert.Equal(t, values[1], "additional-value") + assert.Equal(t, "existing-value", values[0]) + assert.Equal(t, "additional-value", values[1]) }, }, { @@ -143,7 +143,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() if tt.setup != nil { tt.setup(w) @@ -232,7 +232,7 @@ func TestFieldHandler_Query(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) w := httptest.NewRecorder() @@ -330,7 +330,7 @@ func TestFieldHandler_Cookie(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) w := httptest.NewRecorder() @@ -396,7 +396,7 @@ func TestFieldHandler_Body(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) w := httptest.NewRecorder() @@ -440,7 +440,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) w := httptest.NewRecorder() @@ -494,7 +494,7 @@ func TestFieldHandler_StatusCode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() rm := httputils.NewResponseModifier(w) var cmd Command diff --git a/internal/route/rules/http_flow_test.go b/internal/route/rules/http_flow_test.go index cdf13261..ae40edb4 100644 --- a/internal/route/rules/http_flow_test.go +++ b/internal/route/rules/http_flow_test.go @@ -23,9 +23,8 @@ import ( ) // mockUpstream creates a simple upstream handler for testing -func mockUpstream(status int, body string) http.HandlerFunc { +func mockUpstream(body string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(status) w.Write([]byte(body)) } } @@ -51,7 +50,7 @@ func parseRules(data string, target *Rules) error { func TestHTTPFlow_BasicPreRules(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header")) - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("upstream response")) }) @@ -65,18 +64,18 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "upstream response", w.Body.String()) assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header")) } func TestHTTPFlow_BypassRule(t *testing.T) { - upstream := mockUpstream(200, "upstream response") + upstream := mockUpstream("upstream response") var rules Rules err := parseRules(` @@ -91,17 +90,17 @@ func TestHTTPFlow_BypassRule(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/bypass", nil) + req := httptest.NewRequest(http.MethodGet, "/bypass", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "upstream response", w.Body.String()) } func TestHTTPFlow_TerminatingCommand(t *testing.T) { - upstream := mockUpstream(200, "should not be called") + upstream := mockUpstream("should not be called") var rules Rules err := parseRules(` @@ -116,18 +115,18 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/error", nil) + req := httptest.NewRequest(http.MethodGet, "/error", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) assert.Equal(t, "Forbidden\n", w.Body.String()) assert.Empty(t, w.Header().Get("X-Header")) } func TestHTTPFlow_RedirectFlow(t *testing.T) { - upstream := mockUpstream(200, "should not be called") + upstream := mockUpstream("should not be called") var rules Rules err := parseRules(` @@ -139,18 +138,18 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/old-path", nil) + req := httptest.NewRequest(http.MethodGet, "/old-path", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 307, w.Code) // TemporaryRedirect + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect assert.Equal(t, "/new-path", w.Header().Get("Location")) } func TestHTTPFlow_RewriteFlow(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("path: " + r.URL.Path)) }) @@ -164,18 +163,18 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/api/users", nil) + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "path: /v1/users", w.Body.String()) } func TestHTTPFlow_MultiplePreRules(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id"))) }) @@ -192,18 +191,18 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "upstream: req-123", w.Body.String()) assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token")) } func TestHTTPFlow_PostResponseRule(t *testing.T) { - upstream := mockUpstreamWithHeaders(200, "success", http.Header{ + upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{ "X-Upstream": []string{"upstream-value"}, }) @@ -219,12 +218,12 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "success", w.Body.String()) assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) @@ -237,10 +236,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/success" { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("success")) } else { - w.WriteHeader(404) + w.WriteHeader(http.StatusNotFound) w.Write([]byte("not found")) } }) @@ -260,18 +259,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { handler := rules.BuildHandler(upstream) // Test successful request (should not log) - req1 := httptest.NewRequest("GET", "/success", nil) + req1 := httptest.NewRequest(http.MethodGet, "/success", nil) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) // Test error request (should log) - req2 := httptest.NewRequest("GET", "/notfound", nil) + req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 404, w2.Code) + assert.Equal(t, http.StatusNotFound, w2.Code) // Check log file content := TestFileContent(tempFile) @@ -283,7 +282,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { func TestHTTPFlow_ConditionalRules(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("hello " + r.Header.Get("X-Username"))) }) @@ -304,19 +303,19 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) { handler := rules.BuildHandler(upstream) // Test with Authorization header - req1 := httptest.NewRequest("GET", "/", nil) + req1 := httptest.NewRequest(http.MethodGet, "/", nil) req1.Header.Set("Authorization", "Bearer token") w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "hello authenticated-user", w1.Body.String()) assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username")) // Test without Authorization header - req2 := httptest.NewRequest("GET", "/", nil) + req2 := httptest.NewRequest(http.MethodGet, "/", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "hello anonymous", w2.Body.String()) assert.Equal(t, "anonymous", w2.Header().Get("X-Username")) } @@ -326,13 +325,13 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { // Simulate different responses based on path if r.URL.Path == "/protected" { if r.Header.Get("X-Auth") != "valid" { - w.WriteHeader(401) + w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("unauthorized")) return } } w.Header().Set("X-Response-Time", "100ms") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("success")) }) @@ -360,32 +359,32 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { handler := rules.BuildHandler(upstream) // Test successful request - req1 := httptest.NewRequest("GET", "/public", nil) + req1 := httptest.NewRequest(http.MethodGet, "/public", nil) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "success", w1.Body.String()) assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id")) assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time")) // Test unauthorized protected request - req2 := httptest.NewRequest("GET", "/protected", nil) + req2 := httptest.NewRequest(http.MethodGet, "/protected", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 401, w2.Code) - assert.Equal(t, w2.Body.String(), "Unauthorized\n") + assert.Equal(t, http.StatusUnauthorized, w2.Code) + assert.Equal(t, "Unauthorized\n", w2.Body.String()) // Test authorized protected request - req3 := httptest.NewRequest("GET", "/protected", nil) + req3 := httptest.NewRequest(http.MethodGet, "/protected", nil) req3.SetBasicAuth("user", "pass") w3 := httptest.NewRecorder() handler.ServeHTTP(w3, req3) // This should fail because our simple upstream expects X-Auth: valid header // but the basic auth requirement should add the appropriate header - assert.Equal(t, 401, w3.Code) + assert.Equal(t, http.StatusUnauthorized, w3.Code) // Check log files logContent := TestFileContent(logFile) @@ -404,7 +403,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { } func TestHTTPFlow_DefaultRule(t *testing.T) { - upstream := mockUpstream(200, "upstream response") + upstream := mockUpstream("upstream response") var rules Rules err := parseRules(` @@ -419,20 +418,20 @@ func TestHTTPFlow_DefaultRule(t *testing.T) { handler := rules.BuildHandler(upstream) // Test default rule - req1 := httptest.NewRequest("GET", "/regular", nil) + req1 := httptest.NewRequest(http.MethodGet, "/regular", nil) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) assert.Empty(t, w1.Header().Get("X-Special-Handled")) // Test special rule + default rule - req2 := httptest.NewRequest("GET", "/special", nil) + req2 := httptest.NewRequest(http.MethodGet, "/special", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) } @@ -442,7 +441,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { // Echo back a header headerValue := r.Header.Get("X-Test-Header") w.Header().Set("X-Echoed-Header", headerValue) - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("header echoed")) }) @@ -460,14 +459,14 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("X-Secret", "secret-value") req.Header.Set("X-Test-Header", "original-value") w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header")) assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header")) // Ensure the secret header was removed and not passed to upstream @@ -477,7 +476,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { func TestHTTPFlow_QueryParameterHandling(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("query: " + query.Get("param"))) }) @@ -491,25 +490,23 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/path?param=original", nil) + req := httptest.NewRequest(http.MethodGet, "/path?param=original", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) // The set command should have modified the query parameter assert.Equal(t, "query: added-value", w.Body.String()) } func TestHTTPFlow_ServeCommand(t *testing.T) { // Create a temporary directory with test files - tempDir, err := os.MkdirTemp("", "test-serve-*") - require.NoError(t, err) - defer os.RemoveAll(tempDir) + tempDir := t.TempDir() // Create test files directly in the temp directory testFile := filepath.Join(tempDir, "index.html") - err = os.WriteFile(testFile, []byte("

Test Page

"), 0644) + err := os.WriteFile(testFile, []byte("

Test Page

"), 0o644) require.NoError(t, err) var rules Rules @@ -520,7 +517,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) { `, tempDir), &rules) require.NoError(t, err) - handler := rules.BuildHandler(mockUpstream(200, "should not be called")) + handler := rules.BuildHandler(mockUpstream("should not be called")) // Test serving a file - serve command serves files relative to the root directory // The path /files/index.html gets mapped to tempDir + "/files/index.html" @@ -533,7 +530,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) { err = os.WriteFile(filesIndexFile, []byte("

Test Page

"), 0644) require.NoError(t, err) - req1 := httptest.NewRequest("GET", "/files/index.html", nil) + req1 := httptest.NewRequest(http.MethodGet, "/files/index.html", nil) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) @@ -542,18 +539,18 @@ func TestHTTPFlow_ServeCommand(t *testing.T) { assert.NotEqual(t, "should not be called", w1.Body.String()) // Test file not found - req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil) + req2 := httptest.NewRequest(http.MethodGet, "/files/nonexistent.html", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 404, w2.Code) + assert.Equal(t, http.StatusNotFound, w2.Code) } func TestHTTPFlow_ProxyCommand(t *testing.T) { // Create a mock upstream server upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Upstream-Header", "upstream-value") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("upstream response")) })) defer upstreamServer.Close() @@ -566,15 +563,15 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) { `, upstreamServer.URL), &rules) require.NoError(t, err) - handler := rules.BuildHandler(mockUpstream(200, "should not be called")) + handler := rules.BuildHandler(mockUpstream("should not be called")) - req := httptest.NewRequest("GET", "/api/test", nil) + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) // The proxy command should forward the request to the upstream server - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "upstream response", w.Body.String()) assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header")) } @@ -585,7 +582,7 @@ func TestHTTPFlow_NotifyCommand(t *testing.T) { func TestHTTPFlow_FormConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("form processed")) }) @@ -604,28 +601,28 @@ func TestHTTPFlow_FormConditions(t *testing.T) { // Test form condition formData := url.Values{"username": {"john_doe"}} - req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode())) + req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(formData.Encode())) req1.Header.Set("Content-Type", "application/x-www-form-urlencoded") w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "john_doe", w1.Header().Get("X-Username")) // Test postform condition postFormData := url.Values{"email": {"john@example.com"}} - req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode())) + req2 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(postFormData.Encode())) req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "john@example.com", w2.Header().Get("X-Email")) } func TestHTTPFlow_RemoteConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("remote processed")) }) @@ -643,27 +640,27 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) { handler := rules.BuildHandler(upstream) // Test localhost condition - req1 := httptest.NewRequest("GET", "/", nil) + req1 := httptest.NewRequest(http.MethodGet, "/", nil) req1.RemoteAddr = "127.0.0.1:12345" w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "local", w1.Header().Get("X-Access")) // Test private network block - req2 := httptest.NewRequest("GET", "/", nil) + req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2.RemoteAddr = "192.168.1.100:12345" w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 403, w2.Code) + assert.Equal(t, http.StatusForbidden, w2.Code) assert.Equal(t, "Private network blocked\n", w2.Body.String()) } func TestHTTPFlow_BasicAuthConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("auth processed")) }) @@ -687,27 +684,27 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) { handler := rules.BuildHandler(upstream) // Test admin user - req1 := httptest.NewRequest("GET", "/", nil) + req1 := httptest.NewRequest(http.MethodGet, "/", nil) req1.SetBasicAuth("admin", "adminpass") w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status")) // Test guest user - req2 := httptest.NewRequest("GET", "/", nil) + req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2.SetBasicAuth("guest", "guestpass") w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status")) } func TestHTTPFlow_RouteConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("route processed")) }) @@ -725,29 +722,29 @@ func TestHTTPFlow_RouteConditions(t *testing.T) { handler := rules.BuildHandler(upstream) // Test API route - req1 := httptest.NewRequest("GET", "/", nil) + req1 := httptest.NewRequest(http.MethodGet, "/", nil) req1 = routes.WithRouteContext(req1, mockRoute("backend")) w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "backend", w1.Header().Get("X-Route")) // Test admin route - req2 := httptest.NewRequest("GET", "/", nil) + req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2 = routes.WithRouteContext(req2, mockRoute("frontend")) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "frontend", w2.Header().Get("X-Route")) } func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(405) + w.WriteHeader(http.StatusMethodNotAllowed) w.Write([]byte("method not allowed")) }) @@ -762,18 +759,18 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 405, w.Code) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) assert.Equal(t, "error\n", w.Body.String()) } func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Response-Header", "response header") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("processed")) }) @@ -788,11 +785,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 405, w.Code) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) assert.Equal(t, "error\n", w.Body.String()) }) t.Run("with_value", func(t *testing.T) { @@ -806,11 +803,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 405, w.Code) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) assert.Equal(t, "error\n", w.Body.String()) }) @@ -825,18 +822,18 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "processed", w.Body.String()) }) } func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("complex processed")) }) @@ -867,26 +864,26 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { handler := rules.BuildHandler(upstream) // Test admin API (should match first rule) - req1 := httptest.NewRequest("POST", "/api/admin/users", nil) + req1 := httptest.NewRequest(http.MethodPost, "/api/admin/users", nil) req1.Header.Set("Authorization", "Bearer token") w1 := httptest.NewRecorder() handler.ServeHTTP(w1, req1) - assert.Equal(t, 200, w1.Code) + assert.Equal(t, http.StatusOK, w1.Code) assert.Equal(t, "admin", w1.Header().Get("X-Access-Level")) assert.Equal(t, "v1", w1.Header()["X-API-Version"][0]) // Test user API (should match second rule) - req2 := httptest.NewRequest("GET", "/api/users/profile", nil) + req2 := httptest.NewRequest(http.MethodGet, "/api/users/profile", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, 200, w2.Code) + assert.Equal(t, http.StatusOK, w2.Code) assert.Equal(t, "user", w2.Header().Get("X-Access-Level")) assert.Equal(t, "v1", w2.Header()["X-API-Version"][0]) // Test public API (should match third rule) - req3 := httptest.NewRequest("GET", "/api/public/info", nil) + req3 := httptest.NewRequest(http.MethodGet, "/api/public/info", nil) w3 := httptest.NewRecorder() handler.ServeHTTP(w3, req3) @@ -897,7 +894,7 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { func TestHTTPFlow_ResponseModifier(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Write([]byte("original response")) }) @@ -912,12 +909,12 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) { handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "true", w.Header().Get("X-Modified")) assert.Equal(t, "Modified: GET /test\n", w.Body.String()) } diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index ed0293f1..79330268 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -41,6 +41,7 @@ const ( OnRoute = "route" // on response + OnResponseHeader = "resp_header" OnStatus = "status" ) @@ -63,6 +64,7 @@ var checkers = map[string]struct { if len(args) != 0 { return nil, ErrExpectNoArg } + //nolint:nilnil return nil, nil }, builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go index 1b0e8b9c..c8451c86 100644 --- a/internal/route/rules/parser_test.go +++ b/internal/route/rules/parser_test.go @@ -1,11 +1,9 @@ package rules import ( - "os" "strconv" "testing" - gperr "github.com/yusing/goutils/errs" expect "github.com/yusing/goutils/testing" ) @@ -15,7 +13,6 @@ func TestParser(t *testing.T) { input string subject string args []string - wantErr gperr.Error }{ { name: "basic", @@ -93,10 +90,6 @@ func TestParser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { subject, args, err := parse(tt.input) - if tt.wantErr != nil { - expect.ErrorIs(t, tt.wantErr, err) - return - } // t.Log(subject, args, err) expect.NoError(t, err) expect.Equal(t, subject, tt.subject) @@ -105,12 +98,8 @@ func TestParser(t *testing.T) { } t.Run("env substitution", func(t *testing.T) { // Set up test environment variables - os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123") - os.Setenv("DOMAIN", "example.com") - defer func() { - os.Unsetenv("CLOUDFLARE_API_KEY") - os.Unsetenv("DOMAIN") - }() + t.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123") + t.Setenv("DOMAIN", "example.com") tests := []struct { name string diff --git a/internal/route/rules/var_bench_test.go b/internal/route/rules/var_bench_test.go index 74cef1d3..328ad338 100644 --- a/internal/route/rules/var_bench_test.go +++ b/internal/route/rules/var_bench_test.go @@ -2,6 +2,7 @@ package rules import ( "io" + "net/http" "net/http/httptest" "net/url" "testing" @@ -11,9 +12,9 @@ import ( func BenchmarkExpandVars(b *testing.B) { testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) - testResponseModifier.WriteHeader(200) + testResponseModifier.WriteHeader(http.StatusOK) testResponseModifier.Write([]byte("Hello, world!")) - testRequest := httptest.NewRequest("GET", "/", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/", nil) testRequest.Header.Set("User-Agent", "test-agent/1.0") testRequest.Header.Set("X-Custom", "value1,value2") testRequest.ContentLength = 12345 diff --git a/internal/route/rules/vars_test.go b/internal/route/rules/vars_test.go index c719a84a..ca487bad 100644 --- a/internal/route/rules/vars_test.go +++ b/internal/route/rules/vars_test.go @@ -203,7 +203,7 @@ func TestExpandVars(t *testing.T) { postFormData.Add("postmulti", "first") postFormData.Add("postmulti", "second") - testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode())) + testRequest := httptest.NewRequest(http.MethodPost, "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode())) testRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded") testRequest.Header.Set("User-Agent", "test-agent/1.0") testRequest.Header.Add("X-Custom", "value1") @@ -218,7 +218,7 @@ func TestExpandVars(t *testing.T) { testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.Header().Set("Content-Type", "text/html") testResponseModifier.Header().Set("X-Custom-Resp", "resp-value") - testResponseModifier.WriteHeader(200) + testResponseModifier.WriteHeader(http.StatusOK) // set content length to 9876 by writing 9876 'a' bytes testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876)) @@ -498,12 +498,12 @@ func TestExpandVars(t *testing.T) { func TestExpandVars_Integration(t *testing.T) { t.Run("complex log format", func(t *testing.T) { - testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil) + testRequest := httptest.NewRequest(http.MethodGet, "https://api.example.com/users/123?sort=asc", nil) testRequest.Header.Set("User-Agent", "curl/7.68.0") testRequest.RemoteAddr = "10.0.0.1:54321" testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) - testResponseModifier.WriteHeader(200) + testResponseModifier.WriteHeader(http.StatusOK) var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, @@ -515,7 +515,7 @@ func TestExpandVars_Integration(t *testing.T) { }) t.Run("with query parameters", func(t *testing.T) { - testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil) + testRequest := httptest.NewRequest(http.MethodGet, "http://example.com/search?q=test&page=1", nil) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) @@ -529,12 +529,12 @@ func TestExpandVars_Integration(t *testing.T) { }) t.Run("response headers", func(t *testing.T) { - testRequest := httptest.NewRequest("GET", "/", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/", nil) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) testResponseModifier.Header().Set("Cache-Control", "no-cache") testResponseModifier.Header().Set("X-Rate-Limit", "100") - testResponseModifier.WriteHeader(200) + testResponseModifier.WriteHeader(http.StatusOK) var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, @@ -554,7 +554,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) { }{ { name: "http scheme", - request: httptest.NewRequest("GET", "http://example.com/", nil), + request: httptest.NewRequest(http.MethodGet, "http://example.com/", nil), expected: "http", }, { @@ -581,7 +581,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) { func TestExpandVars_UpstreamVariables(t *testing.T) { // Upstream variables require context from routes package - testRequest := httptest.NewRequest("GET", "/", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/", nil) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) @@ -607,7 +607,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) { func TestExpandVars_NoHostPort(t *testing.T) { // Test request without port in Host header - testRequest := httptest.NewRequest("GET", "/", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/", nil) testRequest.Host = "example.com" // No port testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) @@ -623,13 +623,13 @@ func TestExpandVars_NoHostPort(t *testing.T) { var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out) require.NoError(t, err) - require.Equal(t, "", out.String()) + require.Empty(t, out.String()) }) } func TestExpandVars_NoRemotePort(t *testing.T) { // Test request without port in RemoteAddr - testRequest := httptest.NewRequest("GET", "/", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/", nil) testRequest.RemoteAddr = "192.168.1.1" // No port testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) @@ -638,19 +638,19 @@ func TestExpandVars_NoRemotePort(t *testing.T) { var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out) require.NoError(t, err) - require.Equal(t, "", out.String()) + require.Empty(t, out.String()) }) t.Run("remote_port without port", func(t *testing.T) { var out strings.Builder err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out) require.NoError(t, err) - require.Equal(t, "", out.String()) + require.Empty(t, out.String()) }) } func TestExpandVars_WhitespaceHandling(t *testing.T) { - testRequest := httptest.NewRequest("GET", "/test", nil) + testRequest := httptest.NewRequest(http.MethodGet, "/test", nil) testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder diff --git a/internal/route/types/http_config.go b/internal/route/types/http_config.go index 3ba5f769..334897a4 100644 --- a/internal/route/types/http_config.go +++ b/internal/route/types/http_config.go @@ -27,6 +27,7 @@ type HTTPConfig struct { // BuildTLSConfig creates a TLS configuration based on the HTTP config options. func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, error) { + //nolint:gosec tlsConfig := &tls.Config{} // Handle InsecureSkipVerify (legacy NoTLSVerify option) diff --git a/internal/serialization/reader.go b/internal/serialization/reader.go index b44ee3cd..7a4c4167 100644 --- a/internal/serialization/reader.go +++ b/internal/serialization/reader.go @@ -15,8 +15,10 @@ func NewSubstituteEnvReader(reader io.Reader) *SubstituteEnvReader { return &SubstituteEnvReader{reader: reader} } -const peekSize = 4096 -const maxVarNameLength = 256 +const ( + peekSize = 4096 + maxVarNameLength = 256 +) func (r *SubstituteEnvReader) Read(p []byte) (n int, err error) { // Return buffered data first @@ -66,6 +68,7 @@ func (r *SubstituteEnvReader) Read(p []byte) (n int, err error) { if nMore > 0 { incomplete = append(incomplete, more[:nMore]...) // Check if pattern is now complete + //nolint:modernize if idx := bytes.IndexByte(incomplete, '}'); idx >= 0 { // Pattern complete, append the rest back to chunk chunk = append(chunk, incomplete...) diff --git a/internal/serialization/reader_bench_test.go b/internal/serialization/reader_bench_test.go index 7a415b6d..b948c08c 100644 --- a/internal/serialization/reader_bench_test.go +++ b/internal/serialization/reader_bench_test.go @@ -2,8 +2,8 @@ package serialization import ( "bytes" + "errors" "io" - "os" "strings" "testing" ) @@ -11,17 +11,9 @@ import ( // setupEnv sets up environment variables for benchmarks func setupEnv(b *testing.B) { b.Helper() - os.Setenv("BENCH_VAR", "benchmark_value") - os.Setenv("BENCH_VAR_2", "second_value") - os.Setenv("BENCH_VAR_3", "third_value") -} - -// cleanupEnv cleans up environment variables after benchmarks -func cleanupEnv(b *testing.B) { - b.Helper() - os.Unsetenv("BENCH_VAR") - os.Unsetenv("BENCH_VAR_2") - os.Unsetenv("BENCH_VAR_3") + b.Setenv("BENCH_VAR", "benchmark_value") + b.Setenv("BENCH_VAR_2", "second_value") + b.Setenv("BENCH_VAR_3", "third_value") } // BenchmarkSubstituteEnvReader_NoSubstitution benchmarks reading without any env substitutions @@ -44,7 +36,6 @@ data: some content here // BenchmarkSubstituteEnvReader_SingleSubstitution benchmarks reading with a single env substitution func BenchmarkSubstituteEnvReader_SingleSubstitution(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) r := strings.NewReader(`key: ${BENCH_VAR} `) @@ -62,7 +53,6 @@ func BenchmarkSubstituteEnvReader_SingleSubstitution(b *testing.B) { // BenchmarkSubstituteEnvReader_MultipleSubstitutions benchmarks reading with multiple env substitutions func BenchmarkSubstituteEnvReader_MultipleSubstitutions(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) r := strings.NewReader(`key1: ${BENCH_VAR} key2: ${BENCH_VAR_2} @@ -96,7 +86,6 @@ func BenchmarkSubstituteEnvReader_LargeInput_NoSubstitution(b *testing.B) { // BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions benchmarks large input with scattered substitutions func BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) var builder bytes.Buffer for range 100 { @@ -118,7 +107,6 @@ func BenchmarkSubstituteEnvReader_LargeInput_WithSubstitutions(b *testing.B) { // BenchmarkSubstituteEnvReader_SmallBuffer benchmarks reading with a small buffer size func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) r := strings.NewReader(`key: ${BENCH_VAR} and some more content here`) buf := make([]byte, 16) @@ -127,7 +115,7 @@ func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) { reader := NewSubstituteEnvReader(r) for { _, err := reader.Read(buf) - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { @@ -141,7 +129,6 @@ func BenchmarkSubstituteEnvReader_SmallBuffer(b *testing.B) { // BenchmarkSubstituteEnvReader_YAMLConfig benchmarks a realistic YAML config scenario func BenchmarkSubstituteEnvReader_YAMLConfig(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) r := strings.NewReader(`database: host: ${BENCH_VAR} @@ -170,7 +157,6 @@ server: // BenchmarkSubstituteEnvReader_BoundaryPattern benchmarks patterns at buffer boundaries (4096 bytes) func BenchmarkSubstituteEnvReader_BoundaryPattern(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) // Pattern exactly at 4090 bytes, with ${VAR} crossing the 4096 boundary prefix := strings.Repeat("x", 4090) @@ -189,7 +175,6 @@ func BenchmarkSubstituteEnvReader_BoundaryPattern(b *testing.B) { // BenchmarkSubstituteEnvReader_MultipleBoundaries benchmarks multiple patterns crossing boundaries func BenchmarkSubstituteEnvReader_MultipleBoundaries(b *testing.B) { setupEnv(b) - defer cleanupEnv(b) var builder bytes.Buffer for range 10 { @@ -210,8 +195,7 @@ func BenchmarkSubstituteEnvReader_MultipleBoundaries(b *testing.B) { // BenchmarkSubstituteEnvReader_SpecialChars benchmarks substitution with special characters func BenchmarkSubstituteEnvReader_SpecialChars(b *testing.B) { - os.Setenv("SPECIAL_BENCH_VAR", `value with "quotes" and \backslash\`) - defer os.Unsetenv("SPECIAL_BENCH_VAR") + b.Setenv("SPECIAL_BENCH_VAR", `value with "quotes" and \backslash\`) r := strings.NewReader(`key: ${SPECIAL_BENCH_VAR} `) @@ -228,8 +212,7 @@ func BenchmarkSubstituteEnvReader_SpecialChars(b *testing.B) { // BenchmarkSubstituteEnvReader_EmptyValue benchmarks substitution with empty value func BenchmarkSubstituteEnvReader_EmptyValue(b *testing.B) { - os.Setenv("EMPTY_BENCH_VAR", "") - defer os.Unsetenv("EMPTY_BENCH_VAR") + b.Setenv("EMPTY_BENCH_VAR", "") r := strings.NewReader(`key: ${EMPTY_BENCH_VAR} `) @@ -246,8 +229,7 @@ func BenchmarkSubstituteEnvReader_EmptyValue(b *testing.B) { // BenchmarkSubstituteEnvReader_DollarWithoutBrace benchmarks $ without following { func BenchmarkSubstituteEnvReader_DollarWithoutBrace(b *testing.B) { - os.Setenv("BENCH_VAR", "benchmark_value") - defer os.Unsetenv("BENCH_VAR") + b.Setenv("BENCH_VAR", "benchmark_value") r := strings.NewReader(`price: $100 and $200 for ${BENCH_VAR}`) diff --git a/internal/serialization/reader_test.go b/internal/serialization/reader_test.go index 2d9f6961..cbc5d9b6 100644 --- a/internal/serialization/reader_test.go +++ b/internal/serialization/reader_test.go @@ -2,8 +2,8 @@ package serialization import ( "bytes" + "errors" "io" - "os" "strings" "testing" @@ -11,8 +11,7 @@ import ( ) func TestSubstituteEnvReader_Basic(t *testing.T) { - os.Setenv("TEST_VAR", "hello") - defer os.Unsetenv("TEST_VAR") + t.Setenv("TEST_VAR", "hello") input := []byte(`key: ${TEST_VAR}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -23,10 +22,8 @@ func TestSubstituteEnvReader_Basic(t *testing.T) { } func TestSubstituteEnvReader_Multiple(t *testing.T) { - os.Setenv("VAR1", "first") - os.Setenv("VAR2", "second") - defer os.Unsetenv("VAR1") - defer os.Unsetenv("VAR2") + t.Setenv("VAR1", "first") + t.Setenv("VAR2", "second") input := []byte(`a: ${VAR1}, b: ${VAR2}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -46,8 +43,6 @@ func TestSubstituteEnvReader_NoSubstitution(t *testing.T) { } func TestSubstituteEnvReader_UnsetEnvError(t *testing.T) { - os.Unsetenv("UNSET_VAR_FOR_TEST") - input := []byte(`key: ${UNSET_VAR_FOR_TEST}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -57,8 +52,7 @@ func TestSubstituteEnvReader_UnsetEnvError(t *testing.T) { } func TestSubstituteEnvReader_SmallBuffer(t *testing.T) { - os.Setenv("SMALL_BUF_VAR", "value") - defer os.Unsetenv("SMALL_BUF_VAR") + t.Setenv("SMALL_BUF_VAR", "value") input := []byte(`key: ${SMALL_BUF_VAR}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -70,7 +64,7 @@ func TestSubstituteEnvReader_SmallBuffer(t *testing.T) { if n > 0 { result = append(result, buf[:n]...) } - if err == io.EOF { + if errors.Is(err, io.EOF) { break } require.NoError(t, err) @@ -79,8 +73,7 @@ func TestSubstituteEnvReader_SmallBuffer(t *testing.T) { } func TestSubstituteEnvReader_SpecialChars(t *testing.T) { - os.Setenv("SPECIAL_VAR", `hello "world" \n`) - defer os.Unsetenv("SPECIAL_VAR") + t.Setenv("SPECIAL_VAR", `hello "world" \n`) input := []byte(`key: ${SPECIAL_VAR}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -91,8 +84,7 @@ func TestSubstituteEnvReader_SpecialChars(t *testing.T) { } func TestSubstituteEnvReader_EmptyValue(t *testing.T) { - os.Setenv("EMPTY_VAR", "") - defer os.Unsetenv("EMPTY_VAR") + t.Setenv("EMPTY_VAR", "") input := []byte(`key: ${EMPTY_VAR}`) reader := NewSubstituteEnvReader(bytes.NewReader(input)) @@ -103,8 +95,7 @@ func TestSubstituteEnvReader_EmptyValue(t *testing.T) { } func TestSubstituteEnvReader_LargeInput(t *testing.T) { - os.Setenv("LARGE_VAR", "replaced") - defer os.Unsetenv("LARGE_VAR") + t.Setenv("LARGE_VAR", "replaced") prefix := strings.Repeat("x", 5000) suffix := strings.Repeat("y", 5000) @@ -119,8 +110,7 @@ func TestSubstituteEnvReader_LargeInput(t *testing.T) { } func TestSubstituteEnvReader_PatternAtBoundary(t *testing.T) { - os.Setenv("BOUNDARY_VAR", "boundary_value") - defer os.Unsetenv("BOUNDARY_VAR") + t.Setenv("BOUNDARY_VAR", "boundary_value") prefix := strings.Repeat("a", 4090) input := []byte(prefix + "${BOUNDARY_VAR}") @@ -134,10 +124,8 @@ func TestSubstituteEnvReader_PatternAtBoundary(t *testing.T) { } func TestSubstituteEnvReader_MultiplePatternsBoundary(t *testing.T) { - os.Setenv("VAR_A", "aaa") - os.Setenv("VAR_B", "bbb") - defer os.Unsetenv("VAR_A") - defer os.Unsetenv("VAR_B") + t.Setenv("VAR_A", "aaa") + t.Setenv("VAR_B", "bbb") prefix := strings.Repeat("x", 4090) input := []byte(prefix + "${VAR_A} middle ${VAR_B}") @@ -151,12 +139,9 @@ func TestSubstituteEnvReader_MultiplePatternsBoundary(t *testing.T) { } func TestSubstituteEnvReader_YAMLConfig(t *testing.T) { - os.Setenv("DB_HOST", "localhost") - os.Setenv("DB_PORT", "5432") - os.Setenv("DB_PASSWORD", "secret123") - defer os.Unsetenv("DB_HOST") - defer os.Unsetenv("DB_PORT") - defer os.Unsetenv("DB_PASSWORD") + t.Setenv("DB_HOST", "localhost") + t.Setenv("DB_PORT", "5432") + t.Setenv("DB_PASSWORD", "secret123") input := []byte(`database: host: ${DB_HOST} diff --git a/internal/serialization/serialization.go b/internal/serialization/serialization.go index 14781855..5c13116c 100644 --- a/internal/serialization/serialization.go +++ b/internal/serialization/serialization.go @@ -87,7 +87,7 @@ func initPtr(dst reflect.Value) { } } -// Validate performs struct validation using go-playground/validator tags. +// ValidateWithFieldTags performs struct validation using go-playground/validator tags. // // It collects all validation errors and returns them as a single error. // Field names in errors are prefixed with their namespace (e.g., "User.Email"). @@ -521,7 +521,6 @@ func ConvertSlice(src reflect.Value, dst reflect.Value, checkValidateTag bool) e // - Returns true if conversion was handled (even with error), false if // conversion is unsupported. func ConvertString(src string, dst reflect.Value) (convertible bool, convErr error) { - convertible = true dstT := dst.Type() if dst.Kind() == reflect.Pointer { if dst.IsNil() { diff --git a/internal/watcher/config_file_watcher.go b/internal/watcher/config_file_watcher.go index 75a9522d..865bbdd2 100644 --- a/internal/watcher/config_file_watcher.go +++ b/internal/watcher/config_file_watcher.go @@ -17,7 +17,7 @@ func initConfigDirWatcher() { configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath) } -// create a new file watcher for file under ConfigBasePath. +// NewConfigFileWatcher creates a new file watcher for file under common.ConfigBasePath. func NewConfigFileWatcher(filename string) Watcher { configDirWatcherInitOnce.Do(initConfigDirWatcher) return configDirWatcher.Add(filename) diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 306ede63..f8a4c49b 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -124,15 +124,13 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList retry := time.NewTicker(dockerWatcherRetryInterval) defer retry.Stop() - ok := false outer: - for !ok { + for { select { case <-ctx.Done(): return case <-retry.C: if checkConnection(ctx, client) { - ok = true break outer } } diff --git a/internal/watcher/events/events.go b/internal/watcher/events/events.go index f088f0c0..994a8aee 100644 --- a/internal/watcher/events/events.go +++ b/internal/watcher/events/events.go @@ -2,6 +2,7 @@ package events import ( "fmt" + "maps" dockerEvents "github.com/moby/moby/api/types/events" ) @@ -69,9 +70,7 @@ var actionNameMap = func() (m map[Action]string) { for k, v := range DockerEventMap { m[v] = string(k) } - for k, v := range fileActionNameMap { - m[k] = v - } + maps.Copy(m, fileActionNameMap) return m }()