mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-17 14:09:44 +02:00
v0.26.0
This commit is contained in:
@@ -102,7 +102,7 @@ type LoadBalancer struct {
|
||||
func New(cfg *types.LoadBalancerConfig) *LoadBalancer
|
||||
|
||||
// Start the load balancer as a background task
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) error
|
||||
|
||||
// Update configuration dynamically
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *types.LoadBalancerConfig)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/bytedance/gopkg/util/xxhash3"
|
||||
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
@@ -28,10 +27,10 @@ func (lb *LoadBalancer) newIPHash() impl {
|
||||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err gperr.Error
|
||||
var err error
|
||||
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||
if err != nil {
|
||||
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
impl.l.Err(err).Msg("invalid real_ip options, ignoring")
|
||||
}
|
||||
return impl
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ const maxWeight int = 100
|
||||
func New(cfg *types.LoadBalancerConfig) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
LoadBalancerConfig: cfg,
|
||||
pool: pool.New[types.LoadBalancerServer]("loadbalancer." + cfg.Link),
|
||||
pool: pool.New[types.LoadBalancerServer]("loadbalancer."+cfg.Link, "loadbalancers"),
|
||||
l: log.With().Str("name", cfg.Link).Logger(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
@@ -56,7 +56,7 @@ func New(cfg *types.LoadBalancerConfig) *LoadBalancer {
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) error {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = parent.Subtask("loadbalancer."+lb.Link, true)
|
||||
lb.task.OnCancel("cleanup", func() {
|
||||
@@ -234,7 +234,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
if err := errs.Wait().Error(); err != nil {
|
||||
gperr.LogWarn("failed to wake some servers", err, &lb.l)
|
||||
lb.l.Warn().Err(err).Msg("failed to wake some servers")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
routeTypes "github.com/yusing/godoxy/internal/route/types"
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
"github.com/yusing/goutils/task"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
@@ -40,7 +39,7 @@ func TestBypassCIDR(t *testing.T) {
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
recorder := httptest.NewRecorder()
|
||||
mr.ModifyRequest(noOpHandler, recorder, req)
|
||||
@@ -76,7 +75,7 @@ func TestBypassPath(t *testing.T) {
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
mr.ModifyRequest(noOpHandler, recorder, req)
|
||||
expect.NoError(t, err)
|
||||
@@ -126,7 +125,7 @@ func TestReverseProxyBypass(t *testing.T) {
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
rp.ServeHTTP(recorder, req)
|
||||
if test.expectBypass {
|
||||
@@ -160,7 +159,7 @@ func TestBypassResponse(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com"+test.path, nil)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("test")),
|
||||
@@ -201,7 +200,7 @@ func TestBypassResponse(t *testing.T) {
|
||||
StatusCode: test.statusCode,
|
||||
Body: io.NopCloser(strings.NewReader("test")),
|
||||
Header: make(http.Header),
|
||||
Request: httptest.NewRequest("GET", "http://example.com", nil),
|
||||
Request: httptest.NewRequest(http.MethodGet, "http://example.com", nil),
|
||||
}
|
||||
mErr := mr.ModifyResponse(resp)
|
||||
expect.NoError(t, mErr)
|
||||
@@ -230,15 +229,17 @@ func TestEntrypointBypassRoute(t *testing.T) {
|
||||
portInt, err := strconv.Atoi(port)
|
||||
expect.NoError(t, err)
|
||||
|
||||
expect.NoError(t, err)
|
||||
entry := entrypoint.NewEntrypoint()
|
||||
r := &route.Route{
|
||||
Alias: "test-route",
|
||||
Host: host,
|
||||
entry := entrypoint.NewTestEntrypoint(t, nil)
|
||||
_, err = route.NewStartedTestRoute(t, &route.Route{
|
||||
Alias: "test-route",
|
||||
Scheme: routeTypes.SchemeHTTP,
|
||||
Host: host,
|
||||
Port: routeTypes.Port{
|
||||
Proxy: portInt,
|
||||
Listening: 1000,
|
||||
Proxy: portInt,
|
||||
},
|
||||
}
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
|
||||
err = entry.SetMiddlewares([]map[string]any{
|
||||
{
|
||||
@@ -254,13 +255,13 @@ func TestEntrypointBypassRoute(t *testing.T) {
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
|
||||
err = r.Validate()
|
||||
expect.NoError(t, err)
|
||||
r.Start(task.RootTask("test", false))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "http://test-route.example.com", nil)
|
||||
entry.ServeHTTP(recorder, req)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://test-route.example.com", nil)
|
||||
server, ok := entry.GetServer(":1000")
|
||||
if !ok {
|
||||
t.Fatal("server not found")
|
||||
}
|
||||
server.ServeHTTP(recorder, req)
|
||||
expect.Equal(t, recorder.Code, http.StatusOK, "should bypass http redirect")
|
||||
expect.Equal(t, recorder.Body.String(), "test")
|
||||
expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value")
|
||||
|
||||
@@ -249,7 +249,7 @@ The package includes an embedded HTML template (`captcha.html`) that renders the
|
||||
## Error Handling
|
||||
|
||||
```go
|
||||
var ErrCaptchaVerificationFailed = gperr.New("captcha verification failed")
|
||||
var ErrCaptchaVerificationFailed = errors.New("captcha verification failed")
|
||||
|
||||
// Verification errors are logged with request details
|
||||
log.Warn().Err(err).Str("url", r.URL.String()).Str("remote_addr", r.RemoteAddr).Msg("failed to verify captcha")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
@@ -16,4 +15,4 @@ type Provider interface {
|
||||
FormHTML() string
|
||||
}
|
||||
|
||||
var ErrCaptchaVerificationFailed = gperr.New("captcha verification failed")
|
||||
var ErrCaptchaVerificationFailed = errors.New("captcha verification failed")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
httpevents "github.com/yusing/goutils/events/http"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
@@ -71,6 +72,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
defer httpevents.Blocked(r, "CIDRWhitelist", "IP not allowed")
|
||||
http.Error(w, wl.Message, wl.StatusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*nettypes.CIDR) error {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req) //nolint:gosec
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -11,7 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
@@ -48,7 +49,7 @@ func (m *crowdsecMiddleware) setup() {
|
||||
|
||||
func (m *crowdsecMiddleware) finalize() error {
|
||||
if !strings.HasPrefix(m.Endpoint, "/") {
|
||||
return fmt.Errorf("endpoint must start with /")
|
||||
return errors.New("endpoint must start with /")
|
||||
}
|
||||
if m.Timeout == 0 {
|
||||
m.Timeout = 5 * time.Second
|
||||
@@ -66,7 +67,7 @@ func (m *crowdsecMiddleware) finalize() error {
|
||||
// before implements RequestModifier.
|
||||
func (m *crowdsecMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
// Build CrowdSec URL
|
||||
crowdsecURL, err := m.buildCrowdSecURL()
|
||||
crowdsecURL, err := m.buildCrowdSecURL(r.Context())
|
||||
if err != nil {
|
||||
Crowdsec.LogError(r).Err(err).Msg("failed to build CrowdSec URL")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -167,10 +168,10 @@ func (m *crowdsecMiddleware) before(w http.ResponseWriter, r *http.Request) (pro
|
||||
}
|
||||
|
||||
// buildCrowdSecURL constructs the CrowdSec server URL based on route or IP configuration
|
||||
func (m *crowdsecMiddleware) buildCrowdSecURL() (string, error) {
|
||||
func (m *crowdsecMiddleware) buildCrowdSecURL(ctx context.Context) (string, error) {
|
||||
// Try to get route first
|
||||
if m.Route != "" {
|
||||
if route, ok := routes.HTTP.Get(m.Route); ok {
|
||||
if route, ok := entrypoint.FromCtx(ctx).GetRoute(m.Route); ok {
|
||||
// Using route name
|
||||
targetURL := *route.TargetURL()
|
||||
targetURL.Path = m.Endpoint
|
||||
@@ -179,12 +180,12 @@ func (m *crowdsecMiddleware) buildCrowdSecURL() (string, error) {
|
||||
|
||||
// If not found in routes, assume it's an IP address
|
||||
if m.Port == 0 {
|
||||
return "", fmt.Errorf("port must be specified when using IP address")
|
||||
return "", errors.New("port must be specified when using IP address")
|
||||
}
|
||||
return fmt.Sprintf("http://%s%s", net.JoinHostPort(m.Route, strconv.Itoa(m.Port)), m.Endpoint), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("route or IP address must be specified")
|
||||
return "", errors.New("route or IP address must be specified")
|
||||
}
|
||||
|
||||
func (m *crowdsecMiddleware) getHTTPVersion(r *http.Request) string {
|
||||
|
||||
@@ -10,8 +10,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/watcher"
|
||||
"github.com/yusing/godoxy/internal/watcher/events"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
watcherEvents "github.com/yusing/godoxy/internal/watcher/events"
|
||||
"github.com/yusing/goutils/fs"
|
||||
"github.com/yusing/goutils/task"
|
||||
)
|
||||
@@ -81,19 +80,19 @@ func watchDir() {
|
||||
}
|
||||
filename := event.ActorName
|
||||
switch event.Action {
|
||||
case events.ActionFileWritten:
|
||||
case watcherEvents.ActionFileWritten:
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
case watcherEvents.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
case watcherEvents.ActionFileRenamed:
|
||||
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
gperr.LogError("error watching error page directory", err)
|
||||
log.Err(err).Msg("error watching error page directory")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
httpevents "github.com/yusing/goutils/events/http"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"github.com/yusing/goutils/http/httpheaders"
|
||||
)
|
||||
@@ -46,7 +48,7 @@ func (m *forwardAuthMiddleware) setup() {
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
route, ok := routes.HTTP.Get(m.Route)
|
||||
route, ok := entrypoint.FromCtx(r.Context()).HTTPRoutes().Get(m.Route)
|
||||
if !ok {
|
||||
ForwardAuth.LogWarn(r).Str("route", m.Route).Msg("forwardauth route not found")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -92,6 +94,8 @@ func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
defer httpevents.Blocked(r, "ForwardAuth", fmt.Sprintf("HTTP %d", resp.StatusCode))
|
||||
|
||||
body, release, err := httputils.ReadAllBody(resp)
|
||||
defer release(body)
|
||||
|
||||
@@ -100,10 +104,23 @@ func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
httpheaders.CopyHeader(w.Header(), resp.Header)
|
||||
httpheaders.RemoveHopByHopHeaders(w.Header())
|
||||
|
||||
isGet := r.Method == http.MethodGet
|
||||
isWS := httpheaders.IsWebsocket(r.Header)
|
||||
if !isGet || isWS {
|
||||
reqType := r.Method
|
||||
if isWS {
|
||||
reqType = "WebSocket"
|
||||
}
|
||||
ForwardAuth.LogWarn(r).Msgf(
|
||||
"[ForwardAuth] %s request rejected by auth upstream (HTTP %d).\nConsider adding bypass rule for this path if needed",
|
||||
reqType,
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
loc, err := resp.Location()
|
||||
if err != nil {
|
||||
if !errors.Is(err, http.ErrNoLocation) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"reflect"
|
||||
@@ -10,15 +11,12 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"github.com/yusing/goutils/http/httpheaders"
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = gperr.Error
|
||||
|
||||
ReverseProxy = reverseproxy.ReverseProxy
|
||||
ProxyRequest = reverseproxy.ProxyRequest
|
||||
|
||||
@@ -87,7 +85,7 @@ func (m *Middleware) setup() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) error {
|
||||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -120,10 +118,10 @@ func (m *Middleware) finalize() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, error) {
|
||||
if m.construct == nil { // likely a middleware from compose
|
||||
if len(optsRaw) != 0 {
|
||||
return nil, gperr.New("additional options not allowed for middleware").Subject(m.name)
|
||||
return nil, fmt.Errorf("additional options not allowed for middleware %s", m.name)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -133,7 +131,7 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||
return nil, err
|
||||
}
|
||||
if err := mid.finalize(); err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
return nil, err
|
||||
}
|
||||
mid.impl = mid.withCheckBypass()
|
||||
return mid, nil
|
||||
@@ -252,14 +250,13 @@ func (m *Middleware) LogError(req *http.Request) *zerolog.Event {
|
||||
Str("path", req.URL.Path)
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) error {
|
||||
middlewares, err := compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
patchReverseProxy(rp, middlewares)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
|
||||
var ErrMissingMiddlewareUse = errors.New("missing middleware 'use' field")
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
@@ -32,7 +34,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map
|
||||
for name, defs := range rawMap {
|
||||
chain, err := BuildMiddlewareFromChainRaw(name, defs)
|
||||
if err != nil {
|
||||
eb.Add(err.Subject(source))
|
||||
eb.AddSubject(err, source)
|
||||
} else {
|
||||
middlewares[name+"@file"] = chain
|
||||
}
|
||||
@@ -40,7 +42,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map
|
||||
return middlewares
|
||||
}
|
||||
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
var errs gperr.Builder
|
||||
@@ -68,7 +70,7 @@ func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gp
|
||||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, error) {
|
||||
compiled, err := compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -77,7 +79,7 @@ func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, error) {
|
||||
var chainErr gperr.Builder
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
@@ -91,6 +93,7 @@ func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middlewar
|
||||
chainErr.AddSubjectf(err, "%s[%d]", name, i)
|
||||
continue
|
||||
}
|
||||
def = maps.Clone(def)
|
||||
delete(def, "use")
|
||||
m, err := base.New(def)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
@@ -47,7 +48,7 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
}
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
return gperr.Wrap(err).Subjectf("%d", i)
|
||||
return gperr.PrependSubject(err, strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -44,15 +44,14 @@ var allMiddlewares = map[string]*Middleware{
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = gperr.New("unknown middleware")
|
||||
ErrMiddlewareAlreadyExists = gperr.New("middleware with the same name already exists")
|
||||
ErrUnknownMiddleware = errors.New("unknown middleware")
|
||||
ErrMiddlewareAlreadyExists = errors.New("middleware with the same name already exists")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
func Get(name string) (*Middleware, error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
return nil, gperr.PrependSubject(ErrUnknownMiddleware, name).
|
||||
With(gperr.DoYouMeanField(name, allMiddlewares))
|
||||
}
|
||||
return middleware, nil
|
||||
@@ -63,7 +62,7 @@ func All() map[string]*Middleware {
|
||||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
errs := gperr.NewBuilder("middleware compile errors")
|
||||
var errs gperr.Builder
|
||||
middlewareDefs, err := fsutils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -81,7 +80,7 @@ func LoadComposeFiles() {
|
||||
for name, m := range mws {
|
||||
name = strutils.ToLowerNoSnake(name)
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
errs.Add(ErrMiddlewareAlreadyExists.Subject(name))
|
||||
errs.AddSubject(ErrMiddlewareAlreadyExists, name)
|
||||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
@@ -111,6 +110,6 @@ func LoadComposeFiles() {
|
||||
}
|
||||
}
|
||||
if errs.HasError() {
|
||||
gperr.LogError(errs.About(), errs.Error())
|
||||
log.Err(errs.Error()).Msg("middleware compile errors")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/auth"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httpevents "github.com/yusing/goutils/events/http"
|
||||
"github.com/yusing/goutils/http/httpheaders"
|
||||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
@@ -28,7 +30,7 @@ var OIDC = NewMiddleware[oidcMiddleware]()
|
||||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return gperr.New("OIDC not enabled but OIDC middleware is used")
|
||||
log.Error().Msg("OIDC not enabled but OIDC middleware is used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -97,6 +99,10 @@ func (amw *oidcMiddleware) initSlow() error {
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return true
|
||||
}
|
||||
|
||||
if err := amw.init(); err != nil {
|
||||
// no need to log here, main OIDC should've already failed and logged
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@@ -105,7 +111,7 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
||||
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.auth.LogoutHandler(w, r)
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
err := amw.auth.CheckToken(r)
|
||||
@@ -113,11 +119,31 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
||||
return true
|
||||
}
|
||||
|
||||
emitBlockedEvent := func() {
|
||||
if r.Method != http.MethodHead {
|
||||
httpevents.Blocked(r, "OIDC", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
isGet := r.Method == http.MethodGet
|
||||
isWS := httpheaders.IsWebsocket(r.Header)
|
||||
switch {
|
||||
case r.Method == http.MethodHead:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case !isGet, isWS:
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
reqType := r.Method
|
||||
if isWS {
|
||||
reqType = "WebSocket"
|
||||
}
|
||||
OIDC.LogWarn(r).Msgf("[OIDC] %s request blocked.\nConsider adding bypass rule for this path if needed", reqType)
|
||||
emitBlockedEvent()
|
||||
return false
|
||||
case errors.Is(err, auth.ErrMissingOAuthToken):
|
||||
amw.auth.HandleAuth(w, r)
|
||||
default:
|
||||
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), "Logout", auth.OIDCLogoutPath)
|
||||
emitBlockedEvent()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
)
|
||||
|
||||
@@ -121,7 +120,7 @@ func (args *testArgs) bodyReader() io.Reader {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
@@ -135,7 +134,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gpe
|
||||
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||
}
|
||||
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
@@ -160,7 +159,7 @@ func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult,
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TestResult{
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -66,7 +67,7 @@ func (m *themed) finalize() error {
|
||||
m.m.HTML += buf.String()
|
||||
}
|
||||
if m.CSS != "" && m.Theme != "" {
|
||||
return gperr.New("css and theme are mutually exclusive")
|
||||
return errors.New("css and theme are mutually exclusive")
|
||||
}
|
||||
// credit: https://hackcss.egoist.dev
|
||||
if m.Theme != "" {
|
||||
@@ -78,7 +79,7 @@ func (m *themed) finalize() error {
|
||||
case SolarizedDarkTheme:
|
||||
m.m.HTML += wrapStyleTag(solarizedDarkModeCSS)
|
||||
default:
|
||||
return gperr.New("invalid theme").Subject(string(m.Theme))
|
||||
return gperr.PrependSubject(errors.New("invalid theme"), m.Theme)
|
||||
}
|
||||
}
|
||||
if m.CSS != "" {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type Stream interface {
|
||||
ListenAndServe(ctx context.Context, preDial, onRead HookFunc)
|
||||
ListenAndServe(ctx context.Context, preDial, onRead HookFunc) error
|
||||
LocalAddr() net.Addr
|
||||
Close() error
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user