refactor(config): restructured with better concurrency and error handling, reduced cross referencing

This commit is contained in:
yusing
2025-10-09 01:02:24 +08:00
parent d08be872a0
commit cab68807ee
25 changed files with 720 additions and 623 deletions

View File

@@ -1,10 +1,9 @@
package entrypoint
import (
"errors"
"fmt"
"net/http"
"strings"
"sync/atomic"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/logging/accesslog"
@@ -16,15 +15,21 @@ import (
)
type Entrypoint struct {
middleware *middleware.Middleware
accessLogger *accesslog.AccessLogger
findRouteFunc func(host string) (types.HTTPRoute, error)
middleware *middleware.Middleware
accessLogger *accesslog.AccessLogger
findRouteFunc func(host string) types.HTTPRoute
}
var ErrNoSuchRoute = errors.New("no such route")
// nil-safe
var ActiveConfig atomic.Pointer[entrypoint.Config]
func NewEntrypoint() *Entrypoint {
return &Entrypoint{
func init() {
// make sure it's not nil
ActiveConfig.Store(&entrypoint.Config{})
}
func NewEntrypoint() Entrypoint {
return Entrypoint{
findRouteFunc: findRouteAnyDomain,
}
}
@@ -72,8 +77,10 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w = accesslog.NewResponseRecorder(w)
defer ep.accessLogger.Log(r, w.(*accesslog.ResponseRecorder).Response())
}
route, err := ep.findRouteFunc(r.Host)
if err == nil {
route := ep.findRouteFunc(r.Host)
switch {
case route != nil:
r = routes.WithRouteContext(r, route)
if ep.middleware != nil {
ep.middleware.ServeHTTP(route.ServeHTTP, w, r)
@@ -87,11 +94,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
log.Err(err).
log.Error().
Str("method", r.Method).
Str("url", r.URL.String()).
Str("remote", r.RemoteAddr).
Msg("request")
Msgf("not found: %s", r.Host)
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok {
w.WriteHeader(http.StatusNotFound)
@@ -100,39 +107,39 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Err(err).Msg("failed to write error page")
}
} else {
http.Error(w, err.Error(), http.StatusNotFound)
http.NotFound(w, r)
}
}
}
func findRouteAnyDomain(host string) (types.HTTPRoute, error) {
func findRouteAnyDomain(host string) types.HTTPRoute {
idx := strings.IndexByte(host, '.')
if idx != -1 {
target := host[:idx]
if r, ok := routes.HTTP.Get(target); ok {
return r, nil
return r
}
}
if r, ok := routes.HTTP.Get(host); ok {
return r, nil
return r
}
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)
return nil
}
func findRouteByDomains(domains []string) func(host string) (types.HTTPRoute, error) {
return func(host string) (types.HTTPRoute, error) {
func findRouteByDomains(domains []string) func(host string) types.HTTPRoute {
return func(host string) types.HTTPRoute {
for _, domain := range domains {
if target, ok := strings.CutSuffix(host, domain); ok {
if r, ok := routes.HTTP.Get(target); ok {
return r, nil
return r
}
}
}
// fallback to exact match
if r, ok := routes.HTTP.Get(host); ok {
return r, nil
return r
}
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)
return nil
}
}

View File

@@ -29,16 +29,15 @@ func run(t *testing.T, match []string, noMatch []string) {
for _, test := range match {
t.Run(test, func(t *testing.T) {
found, err := ep.findRouteFunc(test)
expect.NoError(t, err)
found := ep.findRouteFunc(test)
expect.NotNil(t, found)
})
}
for _, test := range noMatch {
t.Run(test, func(t *testing.T) {
_, err := ep.findRouteFunc(test)
expect.ErrorIs(t, ErrNoSuchRoute, err)
found := ep.findRouteFunc(test)
expect.Nil(t, found)
})
}
}

View File

@@ -0,0 +1,16 @@
package entrypoint
import (
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/route/rules"
)
type Config struct {
SupportProxyProtocol bool `json:"support_proxy_protocol"`
Rules struct {
CatchAll rules.Rules `json:"catch_all"`
NotFound rules.Rules `json:"not_found"`
} `json:"rules"`
Middlewares []map[string]any `json:"middlewares"`
AccessLog *accesslog.RequestLoggerConfig `json:"access_log" validate:"omitempty"`
}