mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 01:38:50 +02:00
refactor(http,rules): move SharedData and ResponseModifier to httputils
- implemented dependency injection for rule auth handler
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
||||||
"github.com/yusing/godoxy/internal/metrics/uptime"
|
"github.com/yusing/godoxy/internal/metrics/uptime"
|
||||||
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
||||||
|
"github.com/yusing/godoxy/internal/route/rules"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
"github.com/yusing/goutils/server"
|
"github.com/yusing/goutils/server"
|
||||||
"github.com/yusing/goutils/task"
|
"github.com/yusing/goutils/task"
|
||||||
@@ -58,9 +59,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config.StartProxyServers()
|
config.StartProxyServers()
|
||||||
|
|
||||||
if err := auth.Initialize(); err != nil {
|
if err := auth.Initialize(); err != nil {
|
||||||
log.Fatal().Err(err).Msg("failed to initialize authentication")
|
log.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||||
}
|
}
|
||||||
|
rules.InitAuthHandler(auth.AuthOrProceed)
|
||||||
|
|
||||||
// API Handler needs to start after auth is initialized.
|
// API Handler needs to start after auth is initialized.
|
||||||
server.StartServer(task.RootTask("api_server", false), server.Options{
|
server.StartServer(task.RootTask("api_server", false), server.Options{
|
||||||
Name: "api",
|
Name: "api",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/yusing/godoxy/internal/route/rules"
|
"github.com/yusing/godoxy/internal/route/rules"
|
||||||
apitypes "github.com/yusing/goutils/apitypes"
|
apitypes "github.com/yusing/goutils/apitypes"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RawRule struct {
|
type RawRule struct {
|
||||||
@@ -348,7 +349,7 @@ func checkMatchedRules(rulesList rules.Rules, w http.ResponseWriter, r *http.Req
|
|||||||
var matched []string
|
var matched []string
|
||||||
|
|
||||||
// Create a ResponseModifier to properly check rules
|
// Create a ResponseModifier to properly check rules
|
||||||
rm := rules.NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
for _, rule := range rulesList {
|
for _, rule := range rulesList {
|
||||||
// Check if rule matches
|
// Check if rule matches
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/godoxy/internal/auth"
|
"github.com/yusing/godoxy/internal/auth"
|
||||||
"github.com/yusing/godoxy/internal/route/rules"
|
"github.com/yusing/godoxy/internal/route/rules"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bypass []rules.RuleOn
|
type Bypass []rules.RuleOn
|
||||||
@@ -50,7 +51,7 @@ func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *checkBypass) modifyResponse(resp *http.Response) error {
|
func (c *checkBypass) modifyResponse(resp *http.Response) error {
|
||||||
if c.modRes == nil || (!c.isEnforced(resp.Request) && c.bypass.ShouldBypass(rules.ResponseAsRW(resp), resp.Request)) {
|
if c.modRes == nil || (!c.isEnforced(resp.Request) && c.bypass.ShouldBypass(httputils.ResponseAsRW(resp), resp.Request)) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Debug().Str("middleware", c.name).Str("url", resp.Request.Host+resp.Request.URL.Path).Msg("modifying response")
|
log.Debug().Str("middleware", c.name).Str("url", resp.Request.Host+resp.Request.URL.Path).Msg("modifying response")
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"github.com/bytedance/sonic"
|
"github.com/bytedance/sonic"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/godoxy/internal/route/rules"
|
|
||||||
"github.com/yusing/godoxy/internal/serialization"
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
"github.com/yusing/goutils/http/httpheaders"
|
"github.com/yusing/goutils/http/httpheaders"
|
||||||
"github.com/yusing/goutils/http/reverseproxy"
|
"github.com/yusing/goutils/http/reverseproxy"
|
||||||
)
|
)
|
||||||
@@ -197,7 +197,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||||
rm := rules.NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
defer rm.FlushRelease()
|
defer rm.FlushRelease()
|
||||||
next(rm, r)
|
next(rm, r)
|
||||||
|
|
||||||
|
|||||||
@@ -1,108 +0,0 @@
|
|||||||
package rules
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache is a map of cached values for a request.
|
|
||||||
// It prevents the same value from being parsed multiple times.
|
|
||||||
type (
|
|
||||||
Cache map[string]any
|
|
||||||
UpdateFunc[T any] func(T) T
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
cacheKeyQueries = "queries"
|
|
||||||
cacheKeyCookies = "cookies"
|
|
||||||
cacheKeyRemoteIP = "remote_ip"
|
|
||||||
cacheKeyBasicAuth = "basic_auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
var cachePool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return make(Cache)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCache returns a new Cached.
|
|
||||||
func NewCache() Cache {
|
|
||||||
return cachePool.Get().(Cache)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release clear the contents of the Cached and returns it to the pool.
|
|
||||||
func (c Cache) Release() {
|
|
||||||
clear(c)
|
|
||||||
cachePool.Put(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetQueries returns the queries.
|
|
||||||
// If r does not have queries, an empty map is returned.
|
|
||||||
func (c Cache) GetQueries(r *http.Request) url.Values {
|
|
||||||
v, ok := c[cacheKeyQueries]
|
|
||||||
if !ok {
|
|
||||||
v = r.URL.Query()
|
|
||||||
c[cacheKeyQueries] = v
|
|
||||||
}
|
|
||||||
return v.(url.Values)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) {
|
|
||||||
queries := c.GetQueries(r)
|
|
||||||
update(queries)
|
|
||||||
r.URL.RawQuery = queries.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCookies returns the cookies.
|
|
||||||
// If r does not have cookies, an empty slice is returned.
|
|
||||||
func (c Cache) GetCookies(r *http.Request) []*http.Cookie {
|
|
||||||
v, ok := c[cacheKeyCookies]
|
|
||||||
if !ok {
|
|
||||||
v = r.Cookies()
|
|
||||||
c[cacheKeyCookies] = v
|
|
||||||
}
|
|
||||||
return v.([]*http.Cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) {
|
|
||||||
cookies := update(c.GetCookies(r))
|
|
||||||
c[cacheKeyCookies] = cookies
|
|
||||||
r.Header.Del("Cookie")
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
r.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRemoteIP returns the remote ip address.
|
|
||||||
// If r.RemoteAddr is not a valid ip address, nil is returned.
|
|
||||||
func (c Cache) GetRemoteIP(r *http.Request) net.IP {
|
|
||||||
v, ok := c[cacheKeyRemoteIP]
|
|
||||||
if !ok {
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
host = r.RemoteAddr
|
|
||||||
}
|
|
||||||
v = net.ParseIP(host)
|
|
||||||
c[cacheKeyRemoteIP] = v
|
|
||||||
}
|
|
||||||
return v.(net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBasicAuth returns *Credentials the basic auth username and password.
|
|
||||||
// If r does not have basic auth, nil is returned.
|
|
||||||
func (c Cache) GetBasicAuth(r *http.Request) *Credentials {
|
|
||||||
v, ok := c[cacheKeyBasicAuth]
|
|
||||||
if !ok {
|
|
||||||
u, p, ok := r.BasicAuth()
|
|
||||||
if ok {
|
|
||||||
v = &Credentials{u, []byte(p)}
|
|
||||||
c[cacheKeyBasicAuth] = v
|
|
||||||
} else {
|
|
||||||
c[cacheKeyBasicAuth] = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return v.(*Credentials)
|
|
||||||
}
|
|
||||||
@@ -1,16 +1,15 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import "golang.org/x/crypto/bcrypt"
|
import (
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
HashedCrendentials struct {
|
HashedCrendentials struct {
|
||||||
Username string
|
Username string
|
||||||
CheckMatch func(inputPwd []byte) bool
|
CheckMatch func(inputPwd []byte) bool
|
||||||
}
|
}
|
||||||
Credentials struct {
|
|
||||||
Username string
|
|
||||||
Password []byte
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials {
|
func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials {
|
||||||
@@ -19,7 +18,7 @@ func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendenti
|
|||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hc *HashedCrendentials) Match(cred *Credentials) bool {
|
func (hc *HashedCrendentials) Match(cred *httputils.Credentials) bool {
|
||||||
if cred == nil {
|
if cred == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/godoxy/internal/auth"
|
|
||||||
"github.com/yusing/godoxy/internal/logging"
|
"github.com/yusing/godoxy/internal/logging"
|
||||||
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
||||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||||
@@ -50,6 +49,14 @@ const (
|
|||||||
CommandPassAlt = "bypass"
|
CommandPassAlt = "bypass"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
|
|
||||||
|
var authHandler AuthHandler
|
||||||
|
|
||||||
|
func InitAuthHandler(handler AuthHandler) {
|
||||||
|
authHandler = handler
|
||||||
|
}
|
||||||
|
|
||||||
var commands = map[string]struct {
|
var commands = map[string]struct {
|
||||||
help Help
|
help Help
|
||||||
validate ValidateFunc
|
validate ValidateFunc
|
||||||
@@ -70,7 +77,7 @@ var commands = map[string]struct {
|
|||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) CommandHandler {
|
||||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
if !auth.AuthOrProceed(w, r) {
|
if !authHandler(w, r) {
|
||||||
return errTerminated
|
return errTerminated
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -198,7 +205,7 @@ var commands = map[string]struct {
|
|||||||
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
// error command should overwrite the response body
|
// error command should overwrite the response body
|
||||||
GetInitResponseModifier(w).ResetBody()
|
httputils.GetInitResponseModifier(w).ResetBody()
|
||||||
w.WriteHeader(code)
|
w.WriteHeader(code)
|
||||||
err := textTmpl.ExpandVars(w, r, w)
|
err := textTmpl.ExpandVars(w, r, w)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
ioutils "github.com/yusing/goutils/io"
|
ioutils "github.com/yusing/goutils/io"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -128,7 +129,7 @@ var modFields = map[string]struct {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Set(k, v)
|
queries.Set(k, v)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -138,13 +139,13 @@ var modFields = map[string]struct {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Add(k, v)
|
queries.Add(k, v)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Del(k)
|
queries.Del(k)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -169,7 +170,7 @@ var modFields = map[string]struct {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
for i, c := range cookies {
|
for i, c := range cookies {
|
||||||
if c.Name == k {
|
if c.Name == k {
|
||||||
cookies[i].Value = v
|
cookies[i].Value = v
|
||||||
@@ -185,13 +186,13 @@ var modFields = map[string]struct {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
index := -1
|
index := -1
|
||||||
for i, c := range cookies {
|
for i, c := range cookies {
|
||||||
if c.Name == k {
|
if c.Name == k {
|
||||||
@@ -242,7 +243,7 @@ var modFields = map[string]struct {
|
|||||||
r.Body = nil
|
r.Body = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bufPool := GetInitResponseModifier(w).BufPool()
|
bufPool := httputils.GetInitResponseModifier(w).BufPool()
|
||||||
b := bufPool.GetBuffer()
|
b := bufPool.GetBuffer()
|
||||||
err := tmpl.ExpandVars(w, r, b)
|
err := tmpl.ExpandVars(w, r, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -282,7 +283,7 @@ var modFields = map[string]struct {
|
|||||||
tmpl := args.(templateString)
|
tmpl := args.(templateString)
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
rm := GetInitResponseModifier(w)
|
rm := httputils.GetInitResponseModifier(w)
|
||||||
rm.ResetBody()
|
rm.ResetBody()
|
||||||
return tmpl.ExpandVars(w, r, rm)
|
return tmpl.ExpandVars(w, r, rm)
|
||||||
}),
|
}),
|
||||||
@@ -317,7 +318,7 @@ var modFields = map[string]struct {
|
|||||||
status := args.(int)
|
status := args.(int)
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
GetInitResponseModifier(w).WriteHeader(status)
|
httputils.GetInitResponseModifier(w).WriteHeader(status)
|
||||||
return nil
|
return nil
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFieldHandler_Header(t *testing.T) {
|
func TestFieldHandler_Header(t *testing.T) {
|
||||||
@@ -420,7 +421,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
template string
|
template string
|
||||||
setup func(*http.Request)
|
setup func(*http.Request)
|
||||||
verify func(*ResponseModifier)
|
verify func(*httputils.ResponseModifier)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "set response body with template",
|
name: "set response body with template",
|
||||||
@@ -429,8 +430,8 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
|||||||
r.Method = "GET"
|
r.Method = "GET"
|
||||||
r.URL.Path = "/api/test"
|
r.URL.Path = "/api/test"
|
||||||
},
|
},
|
||||||
verify: func(rm *ResponseModifier) {
|
verify: func(rm *httputils.ResponseModifier) {
|
||||||
content := rm.buf.String()
|
content := string(rm.Content())
|
||||||
expected := "Response: GET /api/test"
|
expected := "Response: GET /api/test"
|
||||||
assert.Equal(t, expected, content, "Expected response body")
|
assert.Equal(t, expected, content, "Expected response body")
|
||||||
},
|
},
|
||||||
@@ -444,7 +445,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Create ResponseModifier wrapper
|
// Create ResponseModifier wrapper
|
||||||
rm := NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.template, false)
|
tmpl, tErr := validateTemplate(tt.template, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
@@ -495,7 +496,7 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
rm := NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
var cmd Command
|
var cmd Command
|
||||||
err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status))
|
err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/yusing/godoxy/internal/route/routes"
|
"github.com/yusing/godoxy/internal/route/routes"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RuleOn struct {
|
type RuleOn struct {
|
||||||
@@ -95,11 +96,11 @@ var checkers = map[string]struct {
|
|||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return len(GetInitResponseModifier(w).Header()[k]) > 0
|
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return slices.ContainsFunc(GetInitResponseModifier(w).Header()[k], matcher)
|
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -122,11 +123,11 @@ var checkers = map[string]struct {
|
|||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return len(GetSharedData(w).GetQueries(r)[k]) > 0
|
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return slices.ContainsFunc(GetSharedData(w).GetQueries(r)[k], matcher)
|
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -149,7 +150,7 @@ var checkers = map[string]struct {
|
|||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
cookies := GetSharedData(w).GetCookies(r)
|
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
if cookie.Name == k {
|
if cookie.Name == k {
|
||||||
return true
|
return true
|
||||||
@@ -159,7 +160,7 @@ var checkers = map[string]struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
cookies := GetSharedData(w).GetCookies(r)
|
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
if cookie.Name == k {
|
if cookie.Name == k {
|
||||||
if matcher(cookie.Value) {
|
if matcher(cookie.Value) {
|
||||||
@@ -302,7 +303,7 @@ var checkers = map[string]struct {
|
|||||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||||
wantIP := ipnet.IP
|
wantIP := ipnet.IP
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
ip := GetSharedData(w).GetRemoteIP(r)
|
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -310,7 +311,7 @@ var checkers = map[string]struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
ip := GetSharedData(w).GetRemoteIP(r)
|
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -330,7 +331,7 @@ var checkers = map[string]struct {
|
|||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
cred := args.(*HashedCrendentials)
|
cred := args.(*HashedCrendentials)
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return cred.Match(GetSharedData(w).GetBasicAuth(r))
|
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -378,11 +379,11 @@ var checkers = map[string]struct {
|
|||||||
beg, end := args.(*IntTuple).Unpack()
|
beg, end := args.(*IntTuple).Unpack()
|
||||||
if beg == end {
|
if beg == end {
|
||||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||||
return GetInitResponseModifier(w).StatusCode() == beg
|
return httputils.GetInitResponseModifier(w).StatusCode() == beg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||||
statusCode := GetInitResponseModifier(w).StatusCode()
|
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
|
||||||
return statusCode >= beg && statusCode <= end
|
return statusCode >= beg && statusCode <= end
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,267 +0,0 @@
|
|||||||
package rules
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
gperr "github.com/yusing/goutils/errs"
|
|
||||||
"github.com/yusing/goutils/synk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ResponseModifier struct {
|
|
||||||
bufPool synk.UnsizedBytesPool
|
|
||||||
|
|
||||||
w http.ResponseWriter
|
|
||||||
buf *bytes.Buffer
|
|
||||||
statusCode int
|
|
||||||
shared Cache
|
|
||||||
|
|
||||||
origContentLength int64 // from http.Response in ResponseAsRW, -1 if not set
|
|
||||||
bodyModified bool
|
|
||||||
|
|
||||||
hijacked bool
|
|
||||||
|
|
||||||
errs gperr.Builder
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
StatusCode int
|
|
||||||
Header http.Header
|
|
||||||
}
|
|
||||||
|
|
||||||
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
||||||
for {
|
|
||||||
switch ww := w.(type) {
|
|
||||||
case *ResponseModifier:
|
|
||||||
return ww
|
|
||||||
case interface{ Unwrap() http.ResponseWriter }:
|
|
||||||
w = ww.Unwrap()
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type responseAsRW struct {
|
|
||||||
resp *http.Response
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r responseAsRW) WriteHeader(code int) {
|
|
||||||
log.Error().Msg("write header after response has been created")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r responseAsRW) Write(b []byte) (int, error) {
|
|
||||||
return 0, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r responseAsRW) Header() http.Header {
|
|
||||||
return r.resp.Header
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResponseAsRW(resp *http.Response) *ResponseModifier {
|
|
||||||
return &ResponseModifier{
|
|
||||||
statusCode: resp.StatusCode,
|
|
||||||
w: responseAsRW{resp},
|
|
||||||
origContentLength: resp.ContentLength,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInitResponseModifier returns the response modifier for the given response writer.
|
|
||||||
// If the response writer is already wrapped, it will return the wrapped response modifier.
|
|
||||||
// Otherwise, it will return a new response modifier.
|
|
||||||
func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
||||||
if rm := unwrapResponseModifier(w); rm != nil {
|
|
||||||
return rm
|
|
||||||
}
|
|
||||||
return NewResponseModifier(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSharedData returns the shared data for the given response writer.
|
|
||||||
// It will initialize the shared data if not initialized.
|
|
||||||
func GetSharedData(w http.ResponseWriter) Cache {
|
|
||||||
rm := GetInitResponseModifier(w)
|
|
||||||
if rm.shared == nil {
|
|
||||||
rm.shared = NewCache()
|
|
||||||
}
|
|
||||||
return rm.shared
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewResponseModifier returns a new response modifier for the given response writer.
|
|
||||||
//
|
|
||||||
// It should only be called once, at the very beginning of the request.
|
|
||||||
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
||||||
return &ResponseModifier{
|
|
||||||
bufPool: synk.GetUnsizedBytesPool(),
|
|
||||||
w: w,
|
|
||||||
origContentLength: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) BufPool() synk.UnsizedBytesPool {
|
|
||||||
return rm.bufPool
|
|
||||||
}
|
|
||||||
|
|
||||||
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
|
|
||||||
// return rm.w
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) WriteHeader(code int) {
|
|
||||||
rm.statusCode = code
|
|
||||||
}
|
|
||||||
|
|
||||||
// BodyReader returns a reader for the response body.
|
|
||||||
// Every call to this function will return a new reader that starts from the beginning of the buffer.
|
|
||||||
func (rm *ResponseModifier) BodyReader() io.ReadCloser {
|
|
||||||
if rm.buf == nil {
|
|
||||||
return io.NopCloser(bytes.NewReader(nil))
|
|
||||||
}
|
|
||||||
return io.NopCloser(bytes.NewReader(rm.buf.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) ResetBody() {
|
|
||||||
if !rm.bodyModified {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if rm.buf == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rm.buf.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
|
|
||||||
if rm.buf == nil {
|
|
||||||
rm.buf = rm.bufPool.GetBuffer()
|
|
||||||
} else {
|
|
||||||
rm.buf.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
rm.bodyModified = true
|
|
||||||
|
|
||||||
_, err := io.Copy(rm.buf, r)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to copy body: %w", err)
|
|
||||||
}
|
|
||||||
r.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) ContentLength() int {
|
|
||||||
if !rm.bodyModified {
|
|
||||||
if rm.origContentLength >= 0 {
|
|
||||||
return int(rm.origContentLength)
|
|
||||||
}
|
|
||||||
contentLength, _ := strconv.Atoi(rm.ContentLengthStr())
|
|
||||||
return contentLength
|
|
||||||
}
|
|
||||||
return rm.buf.Len()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) ContentLengthStr() string {
|
|
||||||
if !rm.bodyModified {
|
|
||||||
if rm.origContentLength >= 0 {
|
|
||||||
return strconv.FormatInt(rm.origContentLength, 10)
|
|
||||||
}
|
|
||||||
return rm.w.Header().Get("Content-Length")
|
|
||||||
}
|
|
||||||
return strconv.Itoa(rm.buf.Len())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) Content() []byte {
|
|
||||||
if rm.buf == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return rm.buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) StatusCode() int {
|
|
||||||
if rm.statusCode == 0 {
|
|
||||||
return http.StatusOK
|
|
||||||
}
|
|
||||||
return rm.statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) Header() http.Header {
|
|
||||||
return rm.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) Response() Response {
|
|
||||||
return Response{StatusCode: rm.StatusCode(), Header: rm.Header()}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) Write(b []byte) (int, error) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rm.bodyModified = true
|
|
||||||
if rm.buf == nil {
|
|
||||||
rm.buf = rm.bufPool.GetBuffer()
|
|
||||||
}
|
|
||||||
return rm.buf.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AppendError appends an error to the response modifier
|
|
||||||
// the error will be formatted as "rule <rule.Name> error: <err>"
|
|
||||||
//
|
|
||||||
// It will be aggregated and returned in FlushRelease.
|
|
||||||
func (rm *ResponseModifier) AppendError(rule Rule, err error) {
|
|
||||||
rm.errs.Addf("rule %q error: %w", rule.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
if hijacker, ok := rm.w.(http.Hijacker); ok {
|
|
||||||
rm.hijacked = true
|
|
||||||
return hijacker.Hijack()
|
|
||||||
}
|
|
||||||
return nil, nil, errors.New("hijack not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlushRelease flushes the response modifier and releases the resources
|
|
||||||
// it returns the number of bytes written and the aggregated error
|
|
||||||
// if there is any error (rule errors or write error), it will be returned
|
|
||||||
func (rm *ResponseModifier) FlushRelease() (int, error) {
|
|
||||||
n := 0
|
|
||||||
if !rm.hijacked {
|
|
||||||
if rm.bodyModified {
|
|
||||||
h := rm.w.Header()
|
|
||||||
h.Set("Content-Length", rm.ContentLengthStr())
|
|
||||||
h.Del("Transfer-Encoding")
|
|
||||||
h.Del("Trailer")
|
|
||||||
}
|
|
||||||
rm.w.WriteHeader(rm.StatusCode())
|
|
||||||
|
|
||||||
if rm.bodyModified {
|
|
||||||
if content := rm.Content(); len(content) > 0 {
|
|
||||||
nn, werr := rm.w.Write(content)
|
|
||||||
n += nn
|
|
||||||
if werr != nil {
|
|
||||||
rm.errs.Addf("write error: %w", werr)
|
|
||||||
}
|
|
||||||
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
|
||||||
rm.errs.Addf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// release the buffer and reset the pointers
|
|
||||||
if rm.buf != nil {
|
|
||||||
rm.bufPool.PutBuffer(rm.buf)
|
|
||||||
rm.buf = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// release the shared data
|
|
||||||
if rm.shared != nil {
|
|
||||||
rm.shared.Release()
|
|
||||||
rm.shared = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, rm.errs.Error()
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/quic-go/quic-go/http3"
|
"github.com/quic-go/quic-go/http3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
_ "unsafe"
|
_ "unsafe"
|
||||||
@@ -91,7 +92,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
if defaultRule.IsResponseRule() {
|
if defaultRule.IsResponseRule() {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
rm := NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
if _, err := rm.FlushRelease(); err != nil {
|
||||||
logError(err, r)
|
logError(err, r)
|
||||||
@@ -101,12 +102,12 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
up(w, r)
|
up(w, r)
|
||||||
err := defaultRule.Do.exec.Handle(w, r)
|
err := defaultRule.Do.exec.Handle(w, r)
|
||||||
if err != nil && !errors.Is(err, errTerminated) {
|
if err != nil && !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(defaultRule, err)
|
appendRuleError(rm, &defaultRule, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
rm := NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
if _, err := rm.FlushRelease(); err != nil {
|
||||||
logError(err, r)
|
logError(err, r)
|
||||||
@@ -119,7 +120,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !errors.Is(err, errTerminated) {
|
if !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(defaultRule, err)
|
appendRuleError(rm, &defaultRule, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,7 +139,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
rm := NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
if _, err := rm.FlushRelease(); err != nil {
|
||||||
logError(err, r)
|
logError(err, r)
|
||||||
@@ -157,7 +158,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
err := defaultRule.Handle(w, r)
|
err := defaultRule.Handle(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errTerminated) {
|
if !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(defaultRule, err)
|
appendRuleError(rm, &defaultRule, err)
|
||||||
}
|
}
|
||||||
shouldCallUpstream = false
|
shouldCallUpstream = false
|
||||||
}
|
}
|
||||||
@@ -174,7 +175,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
err := rule.Handle(w, r)
|
err := rule.Handle(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errTerminated) {
|
if !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(rule, err)
|
appendRuleError(rm, &rule, err)
|
||||||
}
|
}
|
||||||
shouldCallUpstream = false
|
shouldCallUpstream = false
|
||||||
break
|
break
|
||||||
@@ -190,7 +191,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
err := defaultRule.Handle(w, r)
|
err := defaultRule.Handle(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errTerminated) {
|
if !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(defaultRule, err)
|
appendRuleError(rm, &defaultRule, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
shouldCallUpstream = false
|
shouldCallUpstream = false
|
||||||
@@ -212,7 +213,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
err := rule.Handle(w, r)
|
err := rule.Handle(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errTerminated) {
|
if !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(rule, err)
|
appendRuleError(rm, &rule, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -222,12 +223,16 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|||||||
if isDefaultRulePost {
|
if isDefaultRulePost {
|
||||||
err := defaultRule.Handle(w, r)
|
err := defaultRule.Handle(w, r)
|
||||||
if err != nil && !errors.Is(err, errTerminated) {
|
if err != nil && !errors.Is(err, errTerminated) {
|
||||||
rm.AppendError(defaultRule, err)
|
appendRuleError(rm, &defaultRule, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
|
||||||
|
rm.AppendError("rule: %s, error: %w", rule.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
func isTerminatingHandler(handler CommandHandler) bool {
|
func isTerminatingHandler(handler CommandHandler) bool {
|
||||||
switch h := handler.(type) {
|
switch h := handler.(type) {
|
||||||
case TerminatingCommand:
|
case TerminatingCommand:
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type templateString struct {
|
type templateString struct {
|
||||||
@@ -27,7 +29,7 @@ func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return ExpandVars(GetInitResponseModifier(w), req, tmpl.string, dstW)
|
return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
||||||
@@ -36,7 +38,7 @@ func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
err := ExpandVars(GetInitResponseModifier(w), req, tmpl.string, &buf)
|
err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkExpandVars(b *testing.B) {
|
func BenchmarkExpandVars(b *testing.B) {
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
testResponseModifier.WriteHeader(200)
|
testResponseModifier.WriteHeader(200)
|
||||||
testResponseModifier.Write([]byte("Hello, world!"))
|
testResponseModifier.Write([]byte("Hello, world!"))
|
||||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
ioutils "github.com/yusing/goutils/io"
|
ioutils "github.com/yusing/goutils/io"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
reqVarGetter func(*http.Request) string
|
reqVarGetter func(*http.Request) string
|
||||||
respVarGetter func(*ResponseModifier) string
|
respVarGetter func(*httputils.ResponseModifier) string
|
||||||
)
|
)
|
||||||
|
|
||||||
var reVar = regexp.MustCompile(`\$[\w_]+`)
|
var reVar = regexp.MustCompile(`\$[\w_]+`)
|
||||||
@@ -36,7 +37,7 @@ func NeedExpandVars(s string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
voidResponseModifier = NewResponseModifier(httptest.NewRecorder())
|
voidResponseModifier = httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
dummyRequest = http.Request{
|
dummyRequest = http.Request{
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
URL: &url.URL{Path: "/"},
|
URL: &url.URL{Path: "/"},
|
||||||
@@ -50,7 +51,7 @@ func ValidateVars(s string) error {
|
|||||||
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
||||||
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
||||||
defer dst.Close()
|
defer dst.Close()
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -14,31 +16,31 @@ var (
|
|||||||
VarPostForm = "postform"
|
VarPostForm = "postform"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dynamicVarGetter func(args []string, w *ResponseModifier, req *http.Request) (string, error)
|
type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||||
|
|
||||||
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
||||||
VarHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(req.Header, key, index)
|
return getValueByKeyAtIndex(req.Header, key, index)
|
||||||
},
|
},
|
||||||
VarResponseHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||||
},
|
},
|
||||||
VarQuery: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(GetSharedData(w).GetQueries(req), key, index)
|
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||||
},
|
},
|
||||||
VarForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -50,7 +52,7 @@ var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
|||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(req.Form, key, index)
|
return getValueByKeyAtIndex(req.Form, key, index)
|
||||||
},
|
},
|
||||||
VarPostForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/yusing/godoxy/internal/route/routes"
|
"github.com/yusing/godoxy/internal/route/routes"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -87,9 +88,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||||
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
|
VarRespContentType: func(resp *httputils.ResponseModifier) string { return resp.Header().Get("Content-Type") },
|
||||||
VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() },
|
VarRespContentLen: func(resp *httputils.ResponseModifier) string { return resp.ContentLengthStr() },
|
||||||
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
|
VarRespStatusCode: func(resp *httputils.ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripFragment(s string) string {
|
func stripFragment(s string) string {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractArgs(t *testing.T) {
|
func TestExtractArgs(t *testing.T) {
|
||||||
@@ -214,7 +215,7 @@ func TestExpandVars(t *testing.T) {
|
|||||||
testRequest.PostForm = postFormData
|
testRequest.PostForm = postFormData
|
||||||
|
|
||||||
// Create response modifier with headers
|
// Create response modifier with headers
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
testResponseModifier.Header().Set("Content-Type", "text/html")
|
testResponseModifier.Header().Set("Content-Type", "text/html")
|
||||||
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
|
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
|
||||||
testResponseModifier.WriteHeader(200)
|
testResponseModifier.WriteHeader(200)
|
||||||
@@ -483,7 +484,7 @@ func TestExpandVars(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest, tt.input, &out)
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -501,11 +502,11 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
testRequest.Header.Set("User-Agent", "curl/7.68.0")
|
testRequest.Header.Set("User-Agent", "curl/7.68.0")
|
||||||
testRequest.RemoteAddr = "10.0.0.1:54321"
|
testRequest.RemoteAddr = "10.0.0.1:54321"
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
testResponseModifier.WriteHeader(200)
|
testResponseModifier.WriteHeader(200)
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest,
|
||||||
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
||||||
&out)
|
&out)
|
||||||
|
|
||||||
@@ -516,7 +517,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
t.Run("with query parameters", func(t *testing.T) {
|
t.Run("with query parameters", func(t *testing.T) {
|
||||||
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
|
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
err := ExpandVars(testResponseModifier, testRequest,
|
||||||
@@ -530,13 +531,13 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
t.Run("response headers", func(t *testing.T) {
|
t.Run("response headers", func(t *testing.T) {
|
||||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
testResponseModifier.Header().Set("Cache-Control", "no-cache")
|
testResponseModifier.Header().Set("Cache-Control", "no-cache")
|
||||||
testResponseModifier.Header().Set("X-Rate-Limit", "100")
|
testResponseModifier.Header().Set("X-Rate-Limit", "100")
|
||||||
testResponseModifier.WriteHeader(200)
|
testResponseModifier.WriteHeader(200)
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
err := ExpandVars(httputils.NewResponseModifier(httptest.NewRecorder()), testRequest,
|
||||||
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
||||||
&out)
|
&out)
|
||||||
|
|
||||||
@@ -569,7 +570,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -582,7 +583,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
|||||||
// Upstream variables require context from routes package
|
// Upstream variables require context from routes package
|
||||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
// Test that upstream variables don't cause errors even when not set
|
// Test that upstream variables don't cause errors even when not set
|
||||||
upstreamVars := []string{
|
upstreamVars := []string{
|
||||||
@@ -609,7 +610,7 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
|||||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||||
testRequest.Host = "example.com" // No port
|
testRequest.Host = "example.com" // No port
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
t.Run("req_host without port", func(t *testing.T) {
|
t.Run("req_host without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
@@ -631,7 +632,7 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
|||||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||||
testRequest.RemoteAddr = "192.168.1.1" // No port
|
testRequest.RemoteAddr = "192.168.1.1" // No port
|
||||||
|
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
t.Run("remote_host without port", func(t *testing.T) {
|
t.Run("remote_host without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
@@ -650,7 +651,7 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
|||||||
|
|
||||||
func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||||
testRequest := httptest.NewRequest("GET", "/test", nil)
|
testRequest := httptest.NewRequest("GET", "/test", nil)
|
||||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||||
|
|||||||
Reference in New Issue
Block a user