feat(entrypoint): implement short link #177

- Added ShortLinkMatcher to handle short link routing.
- Integrated short link handling in Entrypoint.
- Introduced tests for short link matching and dispatching.
- Configured default domain suffix for subdomain aliases.
This commit is contained in:
yusing
2026-01-02 15:42:15 +08:00
parent 1f4c30a48e
commit 590743f1ef
9 changed files with 398 additions and 2 deletions

View File

@@ -13,6 +13,8 @@ var (
IsDebug = env.GetEnvBool("DEBUG", IsTest)
IsTrace = env.GetEnvBool("TRACE", false) && IsDebug
ShortLinkPrefix = env.GetEnvString("SHORTLINK_PREFIX", "go")
ProxyHTTPAddr,
ProxyHTTPHost,
ProxyHTTPPort,

View File

@@ -3,6 +3,8 @@ package config
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"iter"
"net/http"
@@ -134,6 +136,10 @@ func (state *state) EntrypointHandler() http.Handler {
return &state.entrypoint
}
func (state *state) Entrypoint() *entrypoint.Entrypoint {
return &state.entrypoint
}
// AutoCertProvider returns the autocert provider.
//
// If the autocert provider is not configured, it returns nil.
@@ -191,18 +197,52 @@ func (state *state) initAccessLogger() error {
}
func (state *state) initEntrypoint() error {
epCfg := state.Entrypoint
epCfg := state.Config.Entrypoint
matchDomains := state.MatchDomains
state.entrypoint.SetFindRouteDomains(matchDomains)
state.entrypoint.SetNotFoundRules(epCfg.Rules.NotFound)
if len(matchDomains) > 0 {
state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix(matchDomains[0])
}
if state.autocertProvider != nil {
if domain := getAutoCertDefaultDomain(state.autocertProvider); domain != "" {
state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix("." + domain)
}
}
errs := gperr.NewBuilder("entrypoint error")
errs.Add(state.entrypoint.SetMiddlewares(epCfg.Middlewares))
errs.Add(state.entrypoint.SetAccessLogger(state.task, epCfg.AccessLog))
return errs.Error()
}
func getAutoCertDefaultDomain(p *autocert.Provider) string {
if p == nil {
return ""
}
cert, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath())
if err != nil || len(cert.Certificate) == 0 {
return ""
}
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return ""
}
domain := x509Cert.Subject.CommonName
if domain == "" && len(x509Cert.DNSNames) > 0 {
domain = x509Cert.DNSNames[0]
}
domain = strings.TrimSpace(domain)
if after, ok := strings.CutPrefix(domain, "*."); ok {
domain = after
}
return strings.ToLower(domain)
}
func (state *state) initMaxMind() error {
maxmindCfg := state.Providers.MaxMind
if maxmindCfg != nil {

View File

@@ -6,6 +6,7 @@ import (
"iter"
"net/http"
"github.com/yusing/godoxy/internal/entrypoint"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/server"
"github.com/yusing/goutils/synk"
@@ -22,6 +23,7 @@ type State interface {
Value() *Config
EntrypointHandler() http.Handler
Entrypoint() *entrypoint.Entrypoint
AutoCertProvider() server.CertProvider
LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool)

View File

@@ -6,6 +6,7 @@ import (
"sync/atomic"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/common"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
@@ -21,6 +22,7 @@ type Entrypoint struct {
notFoundHandler http.Handler
accessLogger accesslog.AccessLogger
findRouteFunc func(host string) types.HTTPRoute
shortLinkTree *ShortLinkMatcher
}
// nil-safe
@@ -34,9 +36,14 @@ func init() {
func NewEntrypoint() Entrypoint {
return Entrypoint{
findRouteFunc: findRouteAnyDomain,
shortLinkTree: newShortLinkTree(),
}
}
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
return ep.shortLinkTree
}
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
if len(domains) == 0 {
ep.findRouteFunc = findRouteAnyDomain
@@ -104,6 +111,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else {
route.ServeHTTP(w, r)
}
case ep.tryHandleShortLink(w, r):
return
case ep.notFoundHandler != nil:
ep.notFoundHandler.ServeHTTP(w, r)
default:
@@ -111,6 +120,22 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (ep *Entrypoint) tryHandleShortLink(w http.ResponseWriter, r *http.Request) (handled bool) {
host := r.Host
if before, _, ok := strings.Cut(host, ":"); ok {
host = before
}
if strings.EqualFold(host, common.ShortLinkPrefix) {
if ep.middleware != nil {
ep.middleware.ServeHTTP(ep.shortLinkTree.ServeHTTP, w, r)
} else {
ep.shortLinkTree.ServeHTTP(w, r)
}
return true
}
return false
}
func (ep *Entrypoint) serveNotFound(w http.ResponseWriter, r *http.Request) {
// Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway?
// On nginx, when route for domain does not exist, it returns StatusBadGateway.

View File

@@ -0,0 +1,110 @@
package entrypoint
import (
"net/http"
"strings"
"github.com/puzpuzpuz/xsync/v4"
)
type ShortLinkMatcher struct {
defaultDomainSuffix string // e.g. ".example.com"
fqdnRoutes *xsync.Map[string, string] // "app" -> "app.example.com"
subdomainRoutes *xsync.Map[string, struct{}]
}
func newShortLinkTree() *ShortLinkMatcher {
return &ShortLinkMatcher{
fqdnRoutes: xsync.NewMap[string, string](),
subdomainRoutes: xsync.NewMap[string, struct{}](),
}
}
func (st *ShortLinkMatcher) SetDefaultDomainSuffix(suffix string) {
if !strings.HasPrefix(suffix, ".") {
suffix = "." + suffix
}
st.defaultDomainSuffix = suffix
}
func (st *ShortLinkMatcher) AddRoute(alias string) {
alias = strings.TrimSpace(alias)
if alias == "" {
return
}
if strings.Contains(alias, ".") { // FQDN alias
st.fqdnRoutes.Store(alias, alias)
key, _, _ := strings.Cut(alias, ".")
if key != "" {
if _, ok := st.subdomainRoutes.Load(key); !ok {
if _, ok := st.fqdnRoutes.Load(key); !ok {
st.fqdnRoutes.Store(key, alias)
}
}
}
return
}
// subdomain alias + defaultDomainSuffix
if st.defaultDomainSuffix == "" {
return
}
st.subdomainRoutes.Store(alias, struct{}{})
}
func (st *ShortLinkMatcher) DelRoute(alias string) {
alias = strings.TrimSpace(alias)
if alias == "" {
return
}
if strings.Contains(alias, ".") {
st.fqdnRoutes.Delete(alias)
key, _, _ := strings.Cut(alias, ".")
if key != "" {
if target, ok := st.fqdnRoutes.Load(key); ok && target == alias {
st.fqdnRoutes.Delete(key)
}
}
return
}
st.subdomainRoutes.Delete(alias)
}
func (st *ShortLinkMatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.EscapedPath()
trim := strings.TrimPrefix(path, "/")
key, rest, _ := strings.Cut(trim, "/")
if key == "" {
http.Error(w, "short link key is required", http.StatusBadRequest)
return
}
if rest != "" {
rest = "/" + rest
} else {
rest = "/"
}
targetHost := ""
if strings.Contains(key, ".") {
targetHost, _ = st.fqdnRoutes.Load(key)
} else if target, ok := st.fqdnRoutes.Load(key); ok {
targetHost = target
} else if _, ok := st.subdomainRoutes.Load(key); ok && st.defaultDomainSuffix != "" {
targetHost = key + st.defaultDomainSuffix
}
if targetHost == "" {
http.Error(w, "short link not found", http.StatusNotFound)
return
}
targetURL := "https://" + targetHost + rest
if q := r.URL.RawQuery; q != "" {
targetURL += "?" + q
}
http.Redirect(w, r, targetURL, http.StatusTemporaryRedirect)
}

View File

@@ -0,0 +1,194 @@
package entrypoint_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yusing/godoxy/internal/common"
. "github.com/yusing/godoxy/internal/entrypoint"
)
func TestShortLinkMatcher_FQDNAlias(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.AddRoute("app.domain.com")
t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location"))
})
t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/foo/bar", w.Header().Get("Location"))
})
t.Run("with query", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo?x=y&z=1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/foo?x=y&z=1", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_SubdomainAlias(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app")
t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/foo/bar", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_NotFound(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app")
t.Run("missing key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
t.Run("unknown key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/unknown", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestShortLinkMatcher_AddDelRoute(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app1")
matcher.AddRoute("app2.domain.com")
t.Run("both routes work", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location"))
req = httptest.NewRequest("GET", "/app2.domain.com", nil)
w = httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location"))
})
t.Run("delete route", func(t *testing.T) {
matcher.DelRoute("app1")
req := httptest.NewRequest("GET", "/app1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
req = httptest.NewRequest("GET", "/app2.domain.com", nil)
w = httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
// no SetDefaultDomainSuffix called
t.Run("subdomain alias ignored", func(t *testing.T) {
matcher.AddRoute("app")
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("FQDN alias still works", func(t *testing.T) {
matcher.AddRoute("app.domain.com")
req := httptest.NewRequest("GET", "/app.domain.com", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location"))
})
}
func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
ep := NewEntrypoint()
ep.ShortLinkMatcher().SetDefaultDomainSuffix(".example.com")
ep.ShortLinkMatcher().AddRoute("app")
t.Run("shortlink host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = common.ShortLinkPrefix
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("shortlink host with port", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = common.ShortLinkPrefix + ":8080"
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("normal host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = "app.example.com"
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
// Should not redirect, should try normal route lookup (which will 404)
assert.NotEqual(t, http.StatusTemporaryRedirect, w.Code)
})
}

View File

@@ -32,6 +32,9 @@ func setup() {
}
func GetStaticFile(filename string) ([]byte, bool) {
if common.IsTest {
return nil, false
}
setupOnce.Do(setup)
return fileContentMap.Load(filename)
}

View File

@@ -6,6 +6,7 @@ import (
"path"
"path/filepath"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
@@ -124,8 +125,14 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
}
routes.HTTP.Add(s)
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().AddRoute(s.Alias)
}
s.task.OnFinished("remove_route_from_http", func() {
routes.HTTP.Del(s)
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().DelRoute(s.Alias)
}
})
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agentproxy"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/idlewatcher"
"github.com/yusing/godoxy/internal/logging/accesslog"
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
@@ -166,8 +167,14 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
r.addToLoadBalancer(parent)
} else {
routes.HTTP.Add(r)
r.task.OnCancel("remove_route_from_http", func() {
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().AddRoute(r.Alias)
}
r.task.OnCancel("remove_route", func() {
routes.HTTP.Del(r)
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().DelRoute(r.Alias)
}
})
}
return nil
@@ -208,8 +215,14 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
}
linked.SetHealthMonitor(lb)
routes.HTTP.AddKey(cfg.Link, linked)
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().AddRoute(cfg.Link)
}
r.task.OnFinished("remove_loadbalancer_route", func() {
routes.HTTP.DelKey(cfg.Link)
if state := config.WorkingState.Load(); state != nil {
state.Entrypoint().ShortLinkMatcher().DelRoute(cfg.Link)
}
})
lbLock.Unlock()
}