mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-17 23:03:49 +01:00
fix incorrect reload behaviors, further organize code
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
@@ -17,32 +16,31 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var findRouteFunc = findRouteAnyDomain
|
||||
|
||||
var (
|
||||
epMiddleware *middleware.Middleware
|
||||
epMiddlewareMu sync.Mutex
|
||||
|
||||
epAccessLogger *accesslog.AccessLogger
|
||||
epAccessLoggerMu sync.Mutex
|
||||
)
|
||||
type Entrypoint struct {
|
||||
middleware *middleware.Middleware
|
||||
accessLogger *accesslog.AccessLogger
|
||||
findRouteFunc func(host string) (route.HTTPRoute, error)
|
||||
}
|
||||
|
||||
var ErrNoSuchRoute = errors.New("no such route")
|
||||
|
||||
func SetFindRouteDomains(domains []string) {
|
||||
if len(domains) == 0 {
|
||||
findRouteFunc = findRouteAnyDomain
|
||||
} else {
|
||||
findRouteFunc = findRouteByDomains(domains)
|
||||
func NewEntrypoint() *Entrypoint {
|
||||
return &Entrypoint{
|
||||
findRouteFunc: findRouteAnyDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func SetMiddlewares(mws []map[string]any) error {
|
||||
epMiddlewareMu.Lock()
|
||||
defer epMiddlewareMu.Unlock()
|
||||
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
|
||||
if len(domains) == 0 {
|
||||
ep.findRouteFunc = findRouteAnyDomain
|
||||
} else {
|
||||
ep.findRouteFunc = findRouteByDomains(domains)
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
||||
if len(mws) == 0 {
|
||||
epMiddleware = nil
|
||||
ep.middleware = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,22 +48,19 @@ func SetMiddlewares(mws []map[string]any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
epMiddleware = mid
|
||||
ep.middleware = mid
|
||||
|
||||
logger.Debug().Msg("entrypoint middleware loaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
|
||||
epAccessLoggerMu.Lock()
|
||||
defer epAccessLoggerMu.Unlock()
|
||||
|
||||
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
|
||||
if cfg == nil {
|
||||
epAccessLogger = nil
|
||||
ep.accessLogger = nil
|
||||
return
|
||||
}
|
||||
|
||||
epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
|
||||
ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -73,28 +68,18 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func Handler(w http.ResponseWriter, r *http.Request) {
|
||||
mux, err := findRouteFunc(r.Host)
|
||||
func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
mux, err := ep.findRouteFunc(r.Host)
|
||||
if err == nil {
|
||||
if epAccessLogger != nil {
|
||||
epMiddlewareMu.Lock()
|
||||
if epAccessLogger != nil {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
epAccessLogger.Log(r, resp)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
epMiddlewareMu.Unlock()
|
||||
if ep.accessLogger != nil {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
ep.accessLogger.Log(r, resp)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if epMiddleware != nil {
|
||||
epMiddlewareMu.Lock()
|
||||
if epMiddleware != nil {
|
||||
mid := epMiddleware
|
||||
epMiddlewareMu.Unlock()
|
||||
mid.ServeHTTP(mux.ServeHTTP, w, r)
|
||||
return
|
||||
}
|
||||
epMiddlewareMu.Unlock()
|
||||
if ep.middleware != nil {
|
||||
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
|
||||
return
|
||||
}
|
||||
mux.ServeHTTP(w, r)
|
||||
return
|
||||
|
||||
@@ -8,18 +8,19 @@ import (
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
var r route.HTTPRoute
|
||||
var (
|
||||
r route.HTTPRoute
|
||||
ep = NewEntrypoint()
|
||||
)
|
||||
|
||||
func run(t *testing.T, match []string, noMatch []string) {
|
||||
t.Helper()
|
||||
t.Cleanup(routes.TestClear)
|
||||
t.Cleanup(func() {
|
||||
SetFindRouteDomains(nil)
|
||||
})
|
||||
t.Cleanup(func() { ep.SetFindRouteDomains(nil) })
|
||||
|
||||
for _, test := range match {
|
||||
t.Run(test, func(t *testing.T) {
|
||||
found, err := findRouteFunc(test)
|
||||
found, err := ep.findRouteFunc(test)
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, found == &r)
|
||||
})
|
||||
@@ -27,7 +28,7 @@ func run(t *testing.T, match []string, noMatch []string) {
|
||||
|
||||
for _, test := range noMatch {
|
||||
t.Run(test, func(t *testing.T) {
|
||||
_, err := findRouteFunc(test)
|
||||
_, err := ep.findRouteFunc(test)
|
||||
ExpectError(t, ErrNoSuchRoute, err)
|
||||
})
|
||||
}
|
||||
@@ -72,7 +73,7 @@ func TestFindRouteExactHostMatch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFindRouteByDomains(t *testing.T) {
|
||||
SetFindRouteDomains([]string{
|
||||
ep.SetFindRouteDomains([]string{
|
||||
".domain.com",
|
||||
".sub.domain.com",
|
||||
})
|
||||
@@ -97,7 +98,7 @@ func TestFindRouteByDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFindRouteByDomainsExactMatch(t *testing.T) {
|
||||
SetFindRouteDomains([]string{
|
||||
ep.SetFindRouteDomains([]string{
|
||||
".domain.com",
|
||||
".sub.domain.com",
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user