diff --git a/internal/api/csrf_test.go b/internal/api/csrf_test.go index a926eb58..53d576da 100644 --- a/internal/api/csrf_test.go +++ b/internal/api/csrf_test.go @@ -188,12 +188,14 @@ func newAuthenticatedHandler(t *testing.T) *gin.Engine { prevPassword := common.APIPassword prevDisableAuth := common.DebugDisableAuth prevIssuerURL := common.OIDCIssuerURL + prevSkipOriginCheck := common.APISkipOriginCheck common.APIJWTSecret = []byte("0123456789abcdef0123456789abcdef") common.APIUser = "username" common.APIPassword = "password" common.DebugDisableAuth = false common.OIDCIssuerURL = "" + common.APISkipOriginCheck = false t.Cleanup(func() { common.APIJWTSecret = prevSecret @@ -201,6 +203,7 @@ func newAuthenticatedHandler(t *testing.T) *gin.Engine { common.APIPassword = prevPassword common.DebugDisableAuth = prevDisableAuth common.OIDCIssuerURL = prevIssuerURL + common.APISkipOriginCheck = prevSkipOriginCheck }) require.NoError(t, auth.Initialize()) diff --git a/internal/auth/oidc_test.go b/internal/auth/oidc_test.go index 4c9dc6bb..b8959f56 100644 --- a/internal/auth/oidc_test.go +++ b/internal/auth/oidc_test.go @@ -216,7 +216,7 @@ func TestOIDCCallbackHandler(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) if tt.state != "" { req.AddCookie(&http.Cookie{ - Name: CookieOauthState, + Name: defaultAuth.(*OIDCProvider).getAppScopedCookieName(CookieOauthState), Value: tt.state, }) } diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go index 780ac32b..2a48ede3 100644 --- a/internal/jsonstore/jsonstore.go +++ b/internal/jsonstore/jsonstore.go @@ -52,10 +52,6 @@ func loadNS[T store](ns namespace) T { store := reflect.New(reflect.TypeFor[T]().Elem()).Interface().(T) store.Initialize() - if common.IsTest { - return store - } - path := filepath.Join(storesPath, string(ns)+".json") file, err := os.Open(path) if err != nil { @@ -72,7 +68,6 @@ func loadNS[T store](ns namespace) T { Msg("failed to load store") } } - stores[ns] = store log.Debug(). Str("namespace", string(ns)). Str("path", path). diff --git a/internal/jsonstore/jsonstore_test.go b/internal/jsonstore/jsonstore_test.go index 806ea5f7..e2360588 100644 --- a/internal/jsonstore/jsonstore_test.go +++ b/internal/jsonstore/jsonstore_test.go @@ -4,7 +4,18 @@ import ( "testing" ) +func setupTest(t *testing.T) { + prevStoresPath := storesPath + storesPath = t.TempDir() + t.Cleanup(func() { + storesPath = prevStoresPath + clear(stores) + }) +} + func TestNewJSON(t *testing.T) { + setupTest(t) + store := Store[string]("test") store.Store("a", "1") if v, _ := store.Load("a"); v != "1" { @@ -13,9 +24,8 @@ func TestNewJSON(t *testing.T) { } func TestSaveLoadStore(t *testing.T) { - defer clear(stores) + setupTest(t) - storesPath = t.TempDir() store := Store[string]("test") store.Store("a", "1") if err := save(); err != nil { @@ -44,9 +54,8 @@ type testObject struct { func (*testObject) Initialize() {} func TestSaveLoadObject(t *testing.T) { - defer clear(stores) + setupTest(t) - storesPath = t.TempDir() obj := Object[*testObject]("test") obj.I = 1 obj.S = "1" diff --git a/internal/maxmind/maxmind.go b/internal/maxmind/maxmind.go index 5b7ddd52..2699c5b0 100644 --- a/internal/maxmind/maxmind.go +++ b/internal/maxmind/maxmind.go @@ -106,7 +106,9 @@ func (cfg *MaxMind) LoadMaxMindDB(parent task.Parent) error { } else { cfg.Logger().Info().Msg("MaxMind DB loaded") cfg.db.Reader = reader - go cfg.scheduleUpdate(parent) + if !common.IsTest { + go cfg.scheduleUpdate(parent) + } } return nil } diff --git a/internal/net/gphttp/middleware/route_overlay_context_test.go b/internal/net/gphttp/middleware/route_overlay_context_test.go index 1b36dd34..ac27df96 100644 --- a/internal/net/gphttp/middleware/route_overlay_context_test.go +++ b/internal/net/gphttp/middleware/route_overlay_context_test.go @@ -42,7 +42,8 @@ func TestWithConsumedRouteOverlaysReturnsNewRequestWhenOverlayIsPresent(t *testi } type fakeMiddlewareHTTPRoute struct { - name string + name string + targetURL *nettypes.URL } func (r fakeMiddlewareHTTPRoute) Key() string { return r.name } @@ -54,7 +55,7 @@ func (r fakeMiddlewareHTTPRoute) MarshalZerologObject(*zerolog.Event) {} func (r fakeMiddlewareHTTPRoute) ProviderName() string { return "" } func (r fakeMiddlewareHTTPRoute) GetProvider() types.RouteProvider { return nil } func (r fakeMiddlewareHTTPRoute) ListenURL() *nettypes.URL { return nil } -func (r fakeMiddlewareHTTPRoute) TargetURL() *nettypes.URL { return nil } +func (r fakeMiddlewareHTTPRoute) TargetURL() *nettypes.URL { return r.targetURL } func (r fakeMiddlewareHTTPRoute) HealthMonitor() types.HealthMonitor { return nil } func (r fakeMiddlewareHTTPRoute) SetHealthMonitor(types.HealthMonitor) {} func (r fakeMiddlewareHTTPRoute) References() []string { return nil } diff --git a/internal/net/gphttp/middleware/test_utils_test.go b/internal/net/gphttp/middleware/test_utils_test.go index dae8fb79..5819aed7 100644 --- a/internal/net/gphttp/middleware/test_utils_test.go +++ b/internal/net/gphttp/middleware/test_utils_test.go @@ -12,6 +12,7 @@ import ( "github.com/bytedance/sonic" "github.com/yusing/godoxy/internal/common" nettypes "github.com/yusing/godoxy/internal/net/types" + "github.com/yusing/godoxy/internal/route/routes" "github.com/yusing/goutils/http/reverseproxy" ) @@ -161,6 +162,10 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader()) maps.Copy(req.Header, args.headers) + req = routes.WithRouteContext(req, fakeMiddlewareHTTPRoute{ + name: "test-upstream", + targetURL: args.upstreamURL, + }) w := httptest.NewRecorder() diff --git a/internal/net/gphttp/middleware/vars.go b/internal/net/gphttp/middleware/vars.go index c49af043..1c168f9f 100644 --- a/internal/net/gphttp/middleware/vars.go +++ b/internal/net/gphttp/middleware/vars.go @@ -56,9 +56,15 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ if req.TLS != nil { return "https" } + if req.URL != nil && req.URL.Scheme != "" { + return req.URL.Scheme + } return "http" }, VarRequestHost: func(req *http.Request) string { + if req.Host == "" && req.URL != nil { + return req.URL.Hostname() + } reqHost, _, err := net.SplitHostPort(req.Host) if err != nil { return req.Host @@ -66,10 +72,21 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ return reqHost }, VarRequestPort: func(req *http.Request) string { + if req.Host == "" && req.URL != nil { + return req.URL.Port() + } _, reqPort, _ := net.SplitHostPort(req.Host) return reqPort }, - VarRequestAddr: func(req *http.Request) string { return req.Host }, + VarRequestAddr: func(req *http.Request) string { + if req.Host != "" { + return req.Host + } + if req.URL != nil { + return req.URL.Host + } + return "" + }, VarRequestPath: func(req *http.Request) string { return req.URL.Path }, VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery }, VarRequestURL: func(req *http.Request) string { return req.URL.String() }, diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index dcd958be..a55e6665 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "reflect" "strings" + "sync" "testing" "time" @@ -70,13 +71,34 @@ default { assert.Equal(t, "POST /api/users 200 application/json\n", logContent) } +type lockedWriteCloser struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (w *lockedWriteCloser) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.Write(p) +} + +func (w *lockedWriteCloser) Close() error { + return nil +} + +func (w *lockedWriteCloser) String() string { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.String() +} + func TestLogCommand_StdoutAndStderr(t *testing.T) { originalStdout := stdout originalStderr := stderr - var stdoutBuf bytes.Buffer - var stderrBuf bytes.Buffer - stdout = noopWriteCloser{&stdoutBuf} - stderr = noopWriteCloser{&stderrBuf} + stdoutBuf := &lockedWriteCloser{} + stderrBuf := &lockedWriteCloser{} + stdout = stdoutBuf + stderr = stderrBuf defer func() { stdout = originalStdout stderr = originalStderr