fix incorrect reload behaviors, further organize code

This commit is contained in:
yusing
2025-01-09 04:26:00 +08:00
parent 8bbb5d2e09
commit b3c47e759f
26 changed files with 418 additions and 336 deletions

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"strings"
"sync"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/accesslog"
@@ -17,32 +16,31 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var findRouteFunc = findRouteAnyDomain
var (
epMiddleware *middleware.Middleware
epMiddlewareMu sync.Mutex
epAccessLogger *accesslog.AccessLogger
epAccessLoggerMu sync.Mutex
)
type Entrypoint struct {
middleware *middleware.Middleware
accessLogger *accesslog.AccessLogger
findRouteFunc func(host string) (route.HTTPRoute, error)
}
var ErrNoSuchRoute = errors.New("no such route")
func SetFindRouteDomains(domains []string) {
if len(domains) == 0 {
findRouteFunc = findRouteAnyDomain
} else {
findRouteFunc = findRouteByDomains(domains)
func NewEntrypoint() *Entrypoint {
return &Entrypoint{
findRouteFunc: findRouteAnyDomain,
}
}
func SetMiddlewares(mws []map[string]any) error {
epMiddlewareMu.Lock()
defer epMiddlewareMu.Unlock()
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
if len(domains) == 0 {
ep.findRouteFunc = findRouteAnyDomain
} else {
ep.findRouteFunc = findRouteByDomains(domains)
}
}
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
if len(mws) == 0 {
epMiddleware = nil
ep.middleware = nil
return nil
}
@@ -50,22 +48,19 @@ func SetMiddlewares(mws []map[string]any) error {
if err != nil {
return err
}
epMiddleware = mid
ep.middleware = mid
logger.Debug().Msg("entrypoint middleware loaded")
return nil
}
func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
epAccessLoggerMu.Lock()
defer epAccessLoggerMu.Unlock()
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
if cfg == nil {
epAccessLogger = nil
ep.accessLogger = nil
return
}
epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
if err != nil {
return
}
@@ -73,28 +68,18 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
return
}
func Handler(w http.ResponseWriter, r *http.Request) {
mux, err := findRouteFunc(r.Host)
func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mux, err := ep.findRouteFunc(r.Host)
if err == nil {
if epAccessLogger != nil {
epMiddlewareMu.Lock()
if epAccessLogger != nil {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
epAccessLogger.Log(r, resp)
return nil
})
}
epMiddlewareMu.Unlock()
if ep.accessLogger != nil {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
ep.accessLogger.Log(r, resp)
return nil
})
}
if epMiddleware != nil {
epMiddlewareMu.Lock()
if epMiddleware != nil {
mid := epMiddleware
epMiddlewareMu.Unlock()
mid.ServeHTTP(mux.ServeHTTP, w, r)
return
}
epMiddlewareMu.Unlock()
if ep.middleware != nil {
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
return
}
mux.ServeHTTP(w, r)
return

View File

@@ -8,18 +8,19 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var r route.HTTPRoute
var (
r route.HTTPRoute
ep = NewEntrypoint()
)
func run(t *testing.T, match []string, noMatch []string) {
t.Helper()
t.Cleanup(routes.TestClear)
t.Cleanup(func() {
SetFindRouteDomains(nil)
})
t.Cleanup(func() { ep.SetFindRouteDomains(nil) })
for _, test := range match {
t.Run(test, func(t *testing.T) {
found, err := findRouteFunc(test)
found, err := ep.findRouteFunc(test)
ExpectNoError(t, err)
ExpectTrue(t, found == &r)
})
@@ -27,7 +28,7 @@ func run(t *testing.T, match []string, noMatch []string) {
for _, test := range noMatch {
t.Run(test, func(t *testing.T) {
_, err := findRouteFunc(test)
_, err := ep.findRouteFunc(test)
ExpectError(t, ErrNoSuchRoute, err)
})
}
@@ -72,7 +73,7 @@ func TestFindRouteExactHostMatch(t *testing.T) {
}
func TestFindRouteByDomains(t *testing.T) {
SetFindRouteDomains([]string{
ep.SetFindRouteDomains([]string{
".domain.com",
".sub.domain.com",
})
@@ -97,7 +98,7 @@ func TestFindRouteByDomains(t *testing.T) {
}
func TestFindRouteByDomainsExactMatch(t *testing.T) {
SetFindRouteDomains([]string{
ep.SetFindRouteDomains([]string{
".domain.com",
".sub.domain.com",
})