mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-20 23:41:23 +02:00
refactor(config): restructured with better concurrency and error handling, reduced cross referencing
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
package entrypoint
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
@@ -16,15 +15,21 @@ import (
|
||||
)
|
||||
|
||||
type Entrypoint struct {
|
||||
middleware *middleware.Middleware
|
||||
accessLogger *accesslog.AccessLogger
|
||||
findRouteFunc func(host string) (types.HTTPRoute, error)
|
||||
middleware *middleware.Middleware
|
||||
accessLogger *accesslog.AccessLogger
|
||||
findRouteFunc func(host string) types.HTTPRoute
|
||||
}
|
||||
|
||||
var ErrNoSuchRoute = errors.New("no such route")
|
||||
// nil-safe
|
||||
var ActiveConfig atomic.Pointer[entrypoint.Config]
|
||||
|
||||
func NewEntrypoint() *Entrypoint {
|
||||
return &Entrypoint{
|
||||
func init() {
|
||||
// make sure it's not nil
|
||||
ActiveConfig.Store(&entrypoint.Config{})
|
||||
}
|
||||
|
||||
func NewEntrypoint() Entrypoint {
|
||||
return Entrypoint{
|
||||
findRouteFunc: findRouteAnyDomain,
|
||||
}
|
||||
}
|
||||
@@ -72,8 +77,10 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w = accesslog.NewResponseRecorder(w)
|
||||
defer ep.accessLogger.Log(r, w.(*accesslog.ResponseRecorder).Response())
|
||||
}
|
||||
route, err := ep.findRouteFunc(r.Host)
|
||||
if err == nil {
|
||||
|
||||
route := ep.findRouteFunc(r.Host)
|
||||
switch {
|
||||
case route != nil:
|
||||
r = routes.WithRouteContext(r, route)
|
||||
if ep.middleware != nil {
|
||||
ep.middleware.ServeHTTP(route.ServeHTTP, w, r)
|
||||
@@ -87,11 +94,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Then scraper / scanners will know the subdomain is invalid.
|
||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||
log.Err(err).
|
||||
log.Error().
|
||||
Str("method", r.Method).
|
||||
Str("url", r.URL.String()).
|
||||
Str("remote", r.RemoteAddr).
|
||||
Msg("request")
|
||||
Msgf("not found: %s", r.Host)
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||
if ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -100,39 +107,39 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
log.Err(err).Msg("failed to write error page")
|
||||
}
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findRouteAnyDomain(host string) (types.HTTPRoute, error) {
|
||||
func findRouteAnyDomain(host string) types.HTTPRoute {
|
||||
idx := strings.IndexByte(host, '.')
|
||||
if idx != -1 {
|
||||
target := host[:idx]
|
||||
if r, ok := routes.HTTP.Get(target); ok {
|
||||
return r, nil
|
||||
return r
|
||||
}
|
||||
}
|
||||
if r, ok := routes.HTTP.Get(host); ok {
|
||||
return r, nil
|
||||
return r
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)
|
||||
return nil
|
||||
}
|
||||
|
||||
func findRouteByDomains(domains []string) func(host string) (types.HTTPRoute, error) {
|
||||
return func(host string) (types.HTTPRoute, error) {
|
||||
func findRouteByDomains(domains []string) func(host string) types.HTTPRoute {
|
||||
return func(host string) types.HTTPRoute {
|
||||
for _, domain := range domains {
|
||||
if target, ok := strings.CutSuffix(host, domain); ok {
|
||||
if r, ok := routes.HTTP.Get(target); ok {
|
||||
return r, nil
|
||||
return r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fallback to exact match
|
||||
if r, ok := routes.HTTP.Get(host); ok {
|
||||
return r, nil
|
||||
return r
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,16 +29,15 @@ func run(t *testing.T, match []string, noMatch []string) {
|
||||
|
||||
for _, test := range match {
|
||||
t.Run(test, func(t *testing.T) {
|
||||
found, err := ep.findRouteFunc(test)
|
||||
expect.NoError(t, err)
|
||||
found := ep.findRouteFunc(test)
|
||||
expect.NotNil(t, found)
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range noMatch {
|
||||
t.Run(test, func(t *testing.T) {
|
||||
_, err := ep.findRouteFunc(test)
|
||||
expect.ErrorIs(t, ErrNoSuchRoute, err)
|
||||
found := ep.findRouteFunc(test)
|
||||
expect.Nil(t, found)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
16
internal/entrypoint/types/config.go
Normal file
16
internal/entrypoint/types/config.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package entrypoint
|
||||
|
||||
import (
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
SupportProxyProtocol bool `json:"support_proxy_protocol"`
|
||||
Rules struct {
|
||||
CatchAll rules.Rules `json:"catch_all"`
|
||||
NotFound rules.Rules `json:"not_found"`
|
||||
} `json:"rules"`
|
||||
Middlewares []map[string]any `json:"middlewares"`
|
||||
AccessLog *accesslog.RequestLoggerConfig `json:"access_log" validate:"omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user