diff --git a/internal/common/env.go b/internal/common/env.go index 121aa16a..2121612f 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -13,6 +13,8 @@ var ( IsDebug = env.GetEnvBool("DEBUG", IsTest) IsTrace = env.GetEnvBool("TRACE", false) && IsDebug + ShortLinkPrefix = env.GetEnvString("SHORTLINK_PREFIX", "go") + ProxyHTTPAddr, ProxyHTTPHost, ProxyHTTPPort, diff --git a/internal/config/state.go b/internal/config/state.go index 52fa2c71..ffc92dfd 100644 --- a/internal/config/state.go +++ b/internal/config/state.go @@ -3,6 +3,8 @@ package config import ( "bytes" "context" + "crypto/tls" + "crypto/x509" "fmt" "iter" "net/http" @@ -134,6 +136,10 @@ func (state *state) EntrypointHandler() http.Handler { return &state.entrypoint } +func (state *state) Entrypoint() *entrypoint.Entrypoint { + return &state.entrypoint +} + // AutoCertProvider returns the autocert provider. // // If the autocert provider is not configured, it returns nil. @@ -191,18 +197,52 @@ func (state *state) initAccessLogger() error { } func (state *state) initEntrypoint() error { - epCfg := state.Entrypoint + epCfg := state.Config.Entrypoint matchDomains := state.MatchDomains state.entrypoint.SetFindRouteDomains(matchDomains) state.entrypoint.SetNotFoundRules(epCfg.Rules.NotFound) + if len(matchDomains) > 0 { + state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix(matchDomains[0]) + } + + if state.autocertProvider != nil { + if domain := getAutoCertDefaultDomain(state.autocertProvider); domain != "" { + state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix("." + domain) + } + } + errs := gperr.NewBuilder("entrypoint error") errs.Add(state.entrypoint.SetMiddlewares(epCfg.Middlewares)) errs.Add(state.entrypoint.SetAccessLogger(state.task, epCfg.AccessLog)) return errs.Error() } +func getAutoCertDefaultDomain(p *autocert.Provider) string { + if p == nil { + return "" + } + cert, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath()) + if err != nil || len(cert.Certificate) == 0 { + return "" + } + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return "" + } + + domain := x509Cert.Subject.CommonName + if domain == "" && len(x509Cert.DNSNames) > 0 { + domain = x509Cert.DNSNames[0] + } + domain = strings.TrimSpace(domain) + if after, ok := strings.CutPrefix(domain, "*."); ok { + domain = after + } + return strings.ToLower(domain) +} + func (state *state) initMaxMind() error { maxmindCfg := state.Providers.MaxMind if maxmindCfg != nil { diff --git a/internal/config/types/state.go b/internal/config/types/state.go index 83632350..9be9ad62 100644 --- a/internal/config/types/state.go +++ b/internal/config/types/state.go @@ -6,6 +6,7 @@ import ( "iter" "net/http" + "github.com/yusing/godoxy/internal/entrypoint" "github.com/yusing/godoxy/internal/types" "github.com/yusing/goutils/server" "github.com/yusing/goutils/synk" @@ -22,6 +23,7 @@ type State interface { Value() *Config EntrypointHandler() http.Handler + Entrypoint() *entrypoint.Entrypoint AutoCertProvider() server.CertProvider LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool) diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index e7a69426..1931db49 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -6,6 +6,7 @@ import ( "sync/atomic" "github.com/rs/zerolog/log" + "github.com/yusing/godoxy/internal/common" entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" "github.com/yusing/godoxy/internal/logging/accesslog" "github.com/yusing/godoxy/internal/net/gphttp/middleware" @@ -21,6 +22,7 @@ type Entrypoint struct { notFoundHandler http.Handler accessLogger accesslog.AccessLogger findRouteFunc func(host string) types.HTTPRoute + shortLinkTree *ShortLinkMatcher } // nil-safe @@ -34,9 +36,14 @@ func init() { func NewEntrypoint() Entrypoint { return Entrypoint{ findRouteFunc: findRouteAnyDomain, + shortLinkTree: newShortLinkTree(), } } +func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher { + return ep.shortLinkTree +} + func (ep *Entrypoint) SetFindRouteDomains(domains []string) { if len(domains) == 0 { ep.findRouteFunc = findRouteAnyDomain @@ -104,6 +111,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { route.ServeHTTP(w, r) } + case ep.tryHandleShortLink(w, r): + return case ep.notFoundHandler != nil: ep.notFoundHandler.ServeHTTP(w, r) default: @@ -111,6 +120,22 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func (ep *Entrypoint) tryHandleShortLink(w http.ResponseWriter, r *http.Request) (handled bool) { + host := r.Host + if before, _, ok := strings.Cut(host, ":"); ok { + host = before + } + if strings.EqualFold(host, common.ShortLinkPrefix) { + if ep.middleware != nil { + ep.middleware.ServeHTTP(ep.shortLinkTree.ServeHTTP, w, r) + } else { + ep.shortLinkTree.ServeHTTP(w, r) + } + return true + } + return false +} + func (ep *Entrypoint) serveNotFound(w http.ResponseWriter, r *http.Request) { // Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway? // On nginx, when route for domain does not exist, it returns StatusBadGateway. diff --git a/internal/entrypoint/shortlink.go b/internal/entrypoint/shortlink.go new file mode 100644 index 00000000..e3bf20c1 --- /dev/null +++ b/internal/entrypoint/shortlink.go @@ -0,0 +1,110 @@ +package entrypoint + +import ( + "net/http" + "strings" + + "github.com/puzpuzpuz/xsync/v4" +) + +type ShortLinkMatcher struct { + defaultDomainSuffix string // e.g. ".example.com" + + fqdnRoutes *xsync.Map[string, string] // "app" -> "app.example.com" + subdomainRoutes *xsync.Map[string, struct{}] +} + +func newShortLinkTree() *ShortLinkMatcher { + return &ShortLinkMatcher{ + fqdnRoutes: xsync.NewMap[string, string](), + subdomainRoutes: xsync.NewMap[string, struct{}](), + } +} + +func (st *ShortLinkMatcher) SetDefaultDomainSuffix(suffix string) { + if !strings.HasPrefix(suffix, ".") { + suffix = "." + suffix + } + st.defaultDomainSuffix = suffix +} + +func (st *ShortLinkMatcher) AddRoute(alias string) { + alias = strings.TrimSpace(alias) + if alias == "" { + return + } + + if strings.Contains(alias, ".") { // FQDN alias + st.fqdnRoutes.Store(alias, alias) + key, _, _ := strings.Cut(alias, ".") + if key != "" { + if _, ok := st.subdomainRoutes.Load(key); !ok { + if _, ok := st.fqdnRoutes.Load(key); !ok { + st.fqdnRoutes.Store(key, alias) + } + } + } + return + } + + // subdomain alias + defaultDomainSuffix + if st.defaultDomainSuffix == "" { + return + } + st.subdomainRoutes.Store(alias, struct{}{}) +} + +func (st *ShortLinkMatcher) DelRoute(alias string) { + alias = strings.TrimSpace(alias) + if alias == "" { + return + } + + if strings.Contains(alias, ".") { + st.fqdnRoutes.Delete(alias) + key, _, _ := strings.Cut(alias, ".") + if key != "" { + if target, ok := st.fqdnRoutes.Load(key); ok && target == alias { + st.fqdnRoutes.Delete(key) + } + } + return + } + + st.subdomainRoutes.Delete(alias) +} + +func (st *ShortLinkMatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.EscapedPath() + trim := strings.TrimPrefix(path, "/") + key, rest, _ := strings.Cut(trim, "/") + if key == "" { + http.Error(w, "short link key is required", http.StatusBadRequest) + return + } + if rest != "" { + rest = "/" + rest + } else { + rest = "/" + } + + targetHost := "" + if strings.Contains(key, ".") { + targetHost, _ = st.fqdnRoutes.Load(key) + } else if target, ok := st.fqdnRoutes.Load(key); ok { + targetHost = target + } else if _, ok := st.subdomainRoutes.Load(key); ok && st.defaultDomainSuffix != "" { + targetHost = key + st.defaultDomainSuffix + } + + if targetHost == "" { + http.Error(w, "short link not found", http.StatusNotFound) + return + } + + targetURL := "https://" + targetHost + rest + if q := r.URL.RawQuery; q != "" { + targetURL += "?" + q + } + http.Redirect(w, r, targetURL, http.StatusTemporaryRedirect) +} diff --git a/internal/entrypoint/shortlink_test.go b/internal/entrypoint/shortlink_test.go new file mode 100644 index 00000000..6e28a8b0 --- /dev/null +++ b/internal/entrypoint/shortlink_test.go @@ -0,0 +1,194 @@ +package entrypoint_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/yusing/godoxy/internal/common" + . "github.com/yusing/godoxy/internal/entrypoint" +) + +func TestShortLinkMatcher_FQDNAlias(t *testing.T) { + ep := NewEntrypoint() + matcher := ep.ShortLinkMatcher() + matcher.AddRoute("app.domain.com") + + t.Run("exact path", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location")) + }) + + t.Run("with path remainder", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app/foo/bar", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.domain.com/foo/bar", w.Header().Get("Location")) + }) + + t.Run("with query", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app/foo?x=y&z=1", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.domain.com/foo?x=y&z=1", w.Header().Get("Location")) + }) +} + +func TestShortLinkMatcher_SubdomainAlias(t *testing.T) { + ep := NewEntrypoint() + matcher := ep.ShortLinkMatcher() + matcher.SetDefaultDomainSuffix(".example.com") + matcher.AddRoute("app") + + t.Run("exact path", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.example.com/", w.Header().Get("Location")) + }) + + t.Run("with path remainder", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app/foo/bar", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.example.com/foo/bar", w.Header().Get("Location")) + }) +} + +func TestShortLinkMatcher_NotFound(t *testing.T) { + ep := NewEntrypoint() + matcher := ep.ShortLinkMatcher() + matcher.SetDefaultDomainSuffix(".example.com") + matcher.AddRoute("app") + + t.Run("missing key", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("unknown key", func(t *testing.T) { + req := httptest.NewRequest("GET", "/unknown", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) +} + +func TestShortLinkMatcher_AddDelRoute(t *testing.T) { + ep := NewEntrypoint() + matcher := ep.ShortLinkMatcher() + matcher.SetDefaultDomainSuffix(".example.com") + + matcher.AddRoute("app1") + matcher.AddRoute("app2.domain.com") + + t.Run("both routes work", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app1", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location")) + + req = httptest.NewRequest("GET", "/app2.domain.com", nil) + w = httptest.NewRecorder() + matcher.ServeHTTP(w, req) + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location")) + }) + + t.Run("delete route", func(t *testing.T) { + matcher.DelRoute("app1") + + req := httptest.NewRequest("GET", "/app1", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + + req = httptest.NewRequest("GET", "/app2.domain.com", nil) + w = httptest.NewRecorder() + matcher.ServeHTTP(w, req) + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location")) + }) +} + +func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) { + ep := NewEntrypoint() + matcher := ep.ShortLinkMatcher() + // no SetDefaultDomainSuffix called + + t.Run("subdomain alias ignored", func(t *testing.T) { + matcher.AddRoute("app") + + req := httptest.NewRequest("GET", "/app", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("FQDN alias still works", func(t *testing.T) { + matcher.AddRoute("app.domain.com") + + req := httptest.NewRequest("GET", "/app.domain.com", nil) + w := httptest.NewRecorder() + matcher.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location")) + }) +} + +func TestEntrypoint_ShortLinkDispatch(t *testing.T) { + ep := NewEntrypoint() + ep.ShortLinkMatcher().SetDefaultDomainSuffix(".example.com") + ep.ShortLinkMatcher().AddRoute("app") + + t.Run("shortlink host", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app", nil) + req.Host = common.ShortLinkPrefix + w := httptest.NewRecorder() + ep.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.example.com/", w.Header().Get("Location")) + }) + + t.Run("shortlink host with port", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app", nil) + req.Host = common.ShortLinkPrefix + ":8080" + w := httptest.NewRecorder() + ep.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://app.example.com/", w.Header().Get("Location")) + }) + + t.Run("normal host", func(t *testing.T) { + req := httptest.NewRequest("GET", "/app", nil) + req.Host = "app.example.com" + w := httptest.NewRecorder() + ep.ServeHTTP(w, req) + + // Should not redirect, should try normal route lookup (which will 404) + assert.NotEqual(t, http.StatusTemporaryRedirect, w.Code) + }) +} diff --git a/internal/net/gphttp/middleware/errorpage/error_page.go b/internal/net/gphttp/middleware/errorpage/error_page.go index ed4377eb..c9590729 100644 --- a/internal/net/gphttp/middleware/errorpage/error_page.go +++ b/internal/net/gphttp/middleware/errorpage/error_page.go @@ -32,6 +32,9 @@ func setup() { } func GetStaticFile(filename string) ([]byte, bool) { + if common.IsTest { + return nil, false + } setupOnce.Do(setup) return fileContentMap.Load(filename) } diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index e5b06281..c5dd4087 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -6,6 +6,7 @@ import ( "path" "path/filepath" + config "github.com/yusing/godoxy/internal/config/types" "github.com/yusing/godoxy/internal/logging/accesslog" gphttp "github.com/yusing/godoxy/internal/net/gphttp" "github.com/yusing/godoxy/internal/net/gphttp/middleware" @@ -124,8 +125,14 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error { } routes.HTTP.Add(s) + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().AddRoute(s.Alias) + } s.task.OnFinished("remove_route_from_http", func() { routes.HTTP.Del(s) + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().DelRoute(s.Alias) + } }) return nil } diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index caebc79d..b99687c5 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -6,6 +6,7 @@ import ( "github.com/yusing/godoxy/agent/pkg/agent" "github.com/yusing/godoxy/agent/pkg/agentproxy" + config "github.com/yusing/godoxy/internal/config/types" "github.com/yusing/godoxy/internal/idlewatcher" "github.com/yusing/godoxy/internal/logging/accesslog" gphttp "github.com/yusing/godoxy/internal/net/gphttp" @@ -166,8 +167,14 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { r.addToLoadBalancer(parent) } else { routes.HTTP.Add(r) - r.task.OnCancel("remove_route_from_http", func() { + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().AddRoute(r.Alias) + } + r.task.OnCancel("remove_route", func() { routes.HTTP.Del(r) + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().DelRoute(r.Alias) + } }) } return nil @@ -208,8 +215,14 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) { } linked.SetHealthMonitor(lb) routes.HTTP.AddKey(cfg.Link, linked) + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().AddRoute(cfg.Link) + } r.task.OnFinished("remove_loadbalancer_route", func() { routes.HTTP.DelKey(cfg.Link) + if state := config.WorkingState.Load(); state != nil { + state.Entrypoint().ShortLinkMatcher().DelRoute(cfg.Link) + } }) lbLock.Unlock() }