refactor: remove forward auth, move module net/http to net/gphttp

This commit is contained in:
yusing
2025-03-28 07:03:35 +08:00
parent c0c6e21a16
commit 5d2df3550b
69 changed files with 321 additions and 745 deletions

View File

@@ -0,0 +1,82 @@
package middleware
import (
"net"
"net/http"
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
cidrWhitelist struct {
CIDRWhitelistOpts
Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Message string
}
)
var (
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
}
)
func init() {
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
statusCode := fl.Field().Int()
return gphttp.IsStatusCodeValid(int(statusCode))
})
}
// setup implements MiddlewareWithSetup.
func (wl *cidrWhitelist) setup() {
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
wl.cachedAddr = F.NewMapOf[string, bool]()
}
// before implements RequestModifier.
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
return wl.checkIP(w, r)
}
// checkIP checks if the IP address is allowed.
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
}
}
if !allow {
http.Error(w, wl.Message, wl.StatusCode)
return false
}
return true
}

View File

@@ -0,0 +1,91 @@
package middleware
import (
_ "embed"
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/cidr_whitelist_test.yml
var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
const testMessage = "test-message"
t.Run("valid", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
})
ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status": 403,
})
ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status_code": 403,
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/123"},
"message": testMessage,
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestCIDRWhitelist(t *testing.T) {
errs := gperr.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
panic("bug occurred")
}
t.Run("deny", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
}
})
t.Run("accept", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})
}

View File

@@ -0,0 +1,131 @@
package middleware
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type cloudflareRealIP struct {
realIP realIP
Recursive bool
}
const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
cfCIDRsUpdateInterval = time.Hour
cfCIDRsUpdateRetryInterval = 3 * time.Second
)
var (
cfCIDRsLastUpdate atomic.Value[time.Time]
cfCIDRsMu sync.Mutex
// RFC 1918.
localCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, // 127.0.0.1/32
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 240, 0, 0)}, // 172.16.0.0/12
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)}, // 192.168.0.0/16
}
)
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
// setup implements MiddlewareWithSetup.
func (cri *cloudflareRealIP) setup() {
cri.realIP.RealIPOpts = RealIPOpts{
Header: "CF-Connecting-IP",
Recursive: cri.Recursive,
}
}
// before implements RequestModifier.
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.realIP.From = cidrs
}
return cri.realIP.before(w, r)
}
func (cri *cloudflareRealIP) enableTrace() {
cri.realIP.enableTrace()
}
func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return
}
cfCIDRsMu.Lock()
defer cfCIDRsMu.Unlock()
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return
}
if common.IsTest {
cfCIDRs = localCIDRs
} else {
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
)
if err != nil {
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
if len(cfCIDRs) == 0 {
logging.Warn().Msg("cloudflare CIDR range is empty")
}
}
cfCIDRsLastUpdate.Store(time.Now())
logging.Info().Msg("cloudflare CIDR range updated")
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
for _, line := range strutils.SplitLine(string(body)) {
if line == "" {
continue
}
_, cidr, err := net.ParseCIDR(line)
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
}
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
}
*cfCIDRs = append(*cfCIDRs, localCIDRs...)
return nil
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"bytes"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
)
type customErrorPage struct{}
var CustomErrorPage = NewMiddleware[customErrorPage]()
const StaticFilePathPrefix = "/$gperrorpage/"
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
}
// modifyResponse implements ResponseModifier.
func (customErrorPage) modifyResponse(resp *http.Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, StaticFilePathPrefix) {
filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
logging.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
logging.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}

View File

@@ -0,0 +1,96 @@
package errorpage
import (
"fmt"
"os"
"path"
"sync"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var (
setupOnce sync.Once
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
func setup() {
t := task.RootTask("error_page", false)
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
loadContent()
go watchDir()
}
func GetStaticFile(filename string) ([]byte, bool) {
setupOnce.Do(setup)
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
return
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
if fileContentMap.Has(file) {
continue
}
content, err := os.ReadFile(file)
if err != nil {
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(task.RootContext())
for {
select {
case <-task.RootContextCanceled():
return
case event, ok := <-eventCh:
if !ok {
return
}
filename := event.ActorName
switch event.Action {
case events.ActionFileWritten:
fileContentMap.Delete(filename)
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
logging.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
logging.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
gperr.LogError("error watching error page directory", err)
}
}
}

View File

@@ -0,0 +1,44 @@
package metricslogger
import (
"net"
"net/http"
"github.com/yusing/go-proxy/internal/metrics"
)
type MetricsLogger struct {
ServiceName string `json:"service_name"`
}
func NewMetricsLogger(serviceName string) *MetricsLogger {
return &MetricsLogger{serviceName}
}
func (m *MetricsLogger) GetHandler(next http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
m.ServeHTTP(rw, req, next.ServeHTTP)
}
}
func (m *MetricsLogger) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
visitorIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
visitorIP = req.RemoteAddr
}
// req.RemoteAddr had been modified by middleware (if any)
lbls := &metrics.HTTPRouteMetricLabels{
Service: m.ServiceName,
Method: req.Method,
Host: req.Host,
Visitor: visitorIP,
Path: req.URL.Path,
}
next.ServeHTTP(newHTTPMetricLogger(rw, lbls), req)
}
func (m *MetricsLogger) ResetMetrics() {
metrics.GetRouteMetrics().UnregisterService(m.ServiceName)
}

View File

@@ -0,0 +1,47 @@
package metricslogger
import (
"net/http"
"time"
"github.com/yusing/go-proxy/internal/metrics"
)
type httpMetricLogger struct {
http.ResponseWriter
timestamp time.Time
labels *metrics.HTTPRouteMetricLabels
}
// WriteHeader implements http.ResponseWriter.
func (l *httpMetricLogger) WriteHeader(status int) {
l.ResponseWriter.WriteHeader(status)
duration := time.Since(l.timestamp)
go func() {
m := metrics.GetRouteMetrics()
m.HTTPReqTotal.Inc()
m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds()))
// ignore 1xx
switch {
case status >= 500:
m.HTTP5xx.With(l.labels).Inc()
case status >= 400:
m.HTTP4xx.With(l.labels).Inc()
case status >= 200:
m.HTTP2xx3xx.With(l.labels).Inc()
}
}()
}
func (l *httpMetricLogger) Unwrap() http.ResponseWriter {
return l.ResponseWriter
}
func newHTTPMetricLogger(w http.ResponseWriter, labels *metrics.HTTPRouteMetricLabels) *httpMetricLogger {
return &httpMetricLogger{
ResponseWriter: w,
timestamp: time.Now(),
labels: labels,
}
}

View File

@@ -0,0 +1,237 @@
package middleware
import (
"encoding/json"
"net/http"
"reflect"
"sort"
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = gperr.Error
ReverseProxy = reverseproxy.ReverseProxy
ProxyRequest = reverseproxy.ProxyRequest
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
name string
construct ImplNewFunc
impl any
// priority is only applied for ReverseProxy.
//
// Middleware compose follows the order of the slice
//
// Default is 10, 0 is the highest
priority int
}
ByPriority []*Middleware
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareFinalizerWithError interface {
finalize() error
}
MiddlewareWithTracer interface {
enableTrace()
getTracer() *Tracer
}
)
const DefaultPriority = 10
func (m ByPriority) Len() int { return len(m) }
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
func NewMiddleware[ImplType any]() *Middleware {
// type check
t := any(new(ImplType))
switch t.(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
_, hasFinializer := t.(MiddlewareFinalizer)
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
if hasFinializer && hasFinializerWithError {
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
logging.Trace().Msgf("middleware %s enabled trace", m.name)
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.getTracer()
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.SetParent(parent.getTracer())
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
if len(optsRaw) == 0 {
return nil
}
priority, ok := optsRaw["priority"].(int)
if ok {
m.priority = priority
// remove priority for deserialization, restore later
delete(optsRaw, "priority")
defer func() {
optsRaw["priority"] = priority
}()
} else {
m.priority = DefaultPriority
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() error {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
return nil
}
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
return finalizer.finalize()
}
return nil
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
if m.construct == nil { // likely a middleware from compose
if len(optsRaw) != 0 {
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
if err := mid.finalize(); err != nil {
return nil, gperr.Wrap(err)
}
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
"priority": m.priority,
}, "", " ")
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return exec.modifyResponse(resp)
})
}
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
var middlewares []*Middleware
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rp, middlewares)
return
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
sort.Sort(ByPriority(middlewares))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
}
}
}
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mr.modifyResponse
}
}
}

View File

@@ -0,0 +1,107 @@
package middleware
import (
"fmt"
"os"
"path"
"sort"
"github.com/yusing/go-proxy/internal/gperr"
"gopkg.in/yaml.v3"
)
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
eb.Add(err)
return nil
}
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
eb.Add(err)
return nil
}
middlewares := make(map[string]*Middleware)
for name, defs := range rawMap {
chain, err := BuildMiddlewareFromChainRaw(name, defs)
if err != nil {
eb.Add(err.Subject(source))
} else {
middlewares[name+"@file"] = chain
}
}
return middlewares
}
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := gperr.NewBuilder("middlewares compile error")
invalidOpts := gperr.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
middlewares = append(middlewares, m)
}
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
sort.Sort(ByPriority(middlewares))
return middlewares, errs.Error()
}
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
compiled, err := compileMiddlewares(middlewaresMap)
if err != nil {
return nil, err
}
return NewMiddlewareChain(name, compiled), nil
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
chainErr := gperr.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
chainErr.Add(ErrMissingMiddlewareUse.Subjectf("%s[%d]", name, i))
continue
}
baseName := def["use"].(string)
base, err := Get(baseName)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
delete(def, "use")
m, err := base.New(def)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {
return nil, chainErr.Error()
}
return NewMiddlewareChain(name, chain), nil
}

View File

@@ -0,0 +1,22 @@
package middleware
import (
_ "embed"
"encoding/json"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
errs := gperr.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data))
// TODO: test
}

View File

@@ -0,0 +1,61 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
)
type middlewareChain struct {
befores []RequestModifier
modResps []ResponseModifier
}
// TODO: check conflict or duplicates.
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
m := &Middleware{name: name, impl: chainMid}
for _, comp := range chain {
if before, ok := comp.impl.(RequestModifier); ok {
chainMid.befores = append(chainMid.befores, before)
}
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsDebug {
for _, child := range chain {
child.enableTrace()
}
m.enableTrace()
}
return m
}
// before implements RequestModifier.
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
for _, b := range m.befores {
if proceedNext = b.before(w, r); !proceedNext {
return false
}
}
return true
}
// modifyResponse implements ResponseModifier.
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := gperr.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
}
}
return errs.Error()
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"net/http"
"strconv"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
type testPriority struct {
Value int `json:"value"`
}
var test = NewMiddleware[testPriority]()
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
return true
}
func TestMiddlewarePriority(t *testing.T) {
priorities := []int{4, 7, 9, 0}
chain := make([]*Middleware, len(priorities))
for i, p := range priorities {
mid, err := test.New(OptionsRaw{
"priority": p,
"value": i,
})
ExpectNoError(t, err)
chain[i] = mid
}
res, err := newMiddlewaresTest(chain, nil)
ExpectNoError(t, err)
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
}

View File

@@ -0,0 +1,104 @@
package middleware
import (
"path"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
var allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP,
"oidc": OIDC,
"request": ModifyRequest,
"modifyrequest": ModifyRequest,
"response": ModifyResponse,
"modifyresponse": ModifyResponse,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
}
var (
ErrUnknownMiddleware = gperr.New("unknown middleware")
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return allMiddlewares
}
func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
allMiddlewares[name] = m
logging.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
}
}
// build again to resolve cross references
for _, defFile := range middlewareDefs {
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok {
// already loaded above
continue
}
allMiddlewares[name] = m
logging.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
}
}
if errs.HasError() {
gperr.LogError(errs.About(), errs.Error())
}
}

View File

@@ -0,0 +1,91 @@
package middleware
import (
"net/http"
"path/filepath"
"strings"
)
type (
modifyRequest struct {
ModifyRequestOpts
Tracer
}
// order: add_prefix -> set_headers -> add_headers -> hide_headers
ModifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
AddPrefix string
needVarSubstitution bool
}
)
var ModifyRequest = NewMiddleware[modifyRequest]()
// finalize implements MiddlewareFinalizer.
func (mr *ModifyRequestOpts) finalize() {
mr.checkVarSubstitution()
}
// before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
mr.addPrefix(r, nil, r.URL.Path)
mr.modifyHeaders(r, nil, r.Header)
mr.AddTraceRequest("after modify request", r)
return true
}
func (mr *ModifyRequestOpts) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m {
if strings.ContainsRune(v, '$') {
mr.needVarSubstitution = true
return
}
}
}
}
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
if !mr.needVarSubstitution {
for k, v := range mr.SetHeaders {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = v
}()
}
headers[k] = []string{v}
}
for k, v := range mr.AddHeaders {
headers[k] = append(headers[k], v)
}
} else {
for k, v := range mr.SetHeaders {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = varReplace(req, resp, v)
}()
}
headers[k] = []string{varReplace(req, resp, v)}
}
for k, v := range mr.AddHeaders {
headers[k] = append(headers[k], varReplace(req, resp, v))
}
}
for _, k := range mr.HideHeaders {
delete(headers, k)
}
}
func (mr *modifyRequest) addPrefix(r *http.Request, _ *http.Response, path string) {
if len(mr.AddPrefix) == 0 {
return
}
r.URL.Path = filepath.Join(mr.AddPrefix, path)
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"bytes"
"net"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"User-Agent": "go-proxy/v0.5.0",
"Host": VarUpstreamAddr,
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Req-Content-Type": VarRequestContentType,
"X-Test-Req-Content-Length": VarRequestContentLen,
"X-Test-Remote-Host": VarRemoteHost,
"X-Test-Remote-Port": VarRemotePort,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
t.Run("add_prefix", func(t *testing.T) {
tests := []struct {
name string
path string
expectedPath string
upstreamURL string
addPrefix string
}{
{
name: "no prefix",
path: "/foo",
expectedPath: "/foo",
upstreamURL: "http://test.example.com",
},
{
name: "slash only",
path: "/",
expectedPath: "/",
upstreamURL: "http://test.example.com",
addPrefix: "/", // should not change anything
},
{
name: "some prefix",
path: "/test",
expectedPath: "/foo/test",
upstreamURL: "http://test.example.com",
addPrefix: "/foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app" + tt.path)
upstreamURL := types.MustParseURL(tt.upstreamURL)
opts["add_prefix"] = tt.addPrefix
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
})
}
})
}

View File

@@ -0,0 +1,20 @@
package middleware
import (
"net/http"
)
type modifyResponse struct {
ModifyRequestOpts
Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()
// modifyResponse implements ResponseModifier.
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
mr.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.AddTraceResponse("after modify response", resp)
return nil
}

View File

@@ -0,0 +1,108 @@
package middleware
import (
"bytes"
"net"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"X-Test-Resp-Status": VarRespStatusCode,
"X-Test-Resp-Content-Type": VarRespContentType,
"X-Test-Resp-Content-Length": VarRespContentLen,
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Req-Content-Type": VarRequestContentType,
"X-Test-Req-Content-Length": VarRequestContentLen,
"X-Test-Remote-Host": VarRemoteHost,
"X-Test-Remote-Port": VarRemotePort,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("response_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
respHeaders: http.Header{
"Content-Type": []string{"application/json"},
},
respBody: bytes.Repeat([]byte("a"), 50),
respStatus: http.StatusOK,
})
ExpectNoError(t, err)
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
}

View File

@@ -0,0 +1,93 @@
package middleware
import (
"errors"
"net/http"
"sync"
"sync/atomic"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/gperr"
)
type oidcMiddleware struct {
AllowedUsers []string `json:"allowed_users"`
AllowedGroups []string `json:"allowed_groups"`
auth auth.Provider
authMux *http.ServeMux
isInitialized int32
initMu sync.Mutex
}
var OIDC = NewMiddleware[oidcMiddleware]()
func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() {
return gperr.New("OIDC not enabled but OIDC middleware is used")
}
return nil
}
func (amw *oidcMiddleware) init() error {
if atomic.LoadInt32(&amw.isInitialized) == 1 {
return nil
}
return amw.initSlow()
}
func (amw *oidcMiddleware) initSlow() error {
amw.initMu.Lock()
if amw.isInitialized == 1 {
amw.initMu.Unlock()
return nil
}
defer func() {
amw.isInitialized = 1
amw.initMu.Unlock()
}()
authProvider, err := auth.NewOIDCProviderFromEnv()
if err != nil {
return err
}
authProvider.SetIsMiddleware(true)
if len(amw.AllowedUsers) > 0 {
authProvider.SetAllowedUsers(amw.AllowedUsers)
}
if len(amw.AllowedGroups) > 0 {
authProvider.SetAllowedGroups(amw.AllowedGroups)
}
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
amw.auth = authProvider
return nil
}
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if err := amw.init(); err != nil {
// no need to log here, main OIDC may already failed and logged
http.Error(w, err.Error(), http.StatusInternalServerError)
return false
}
if r.URL.Path == auth.OIDCLogoutPath {
amw.auth.LogoutCallbackHandler(w, r)
return false
}
if err := amw.auth.CheckToken(r); err != nil {
if errors.Is(err, auth.ErrMissingToken) {
amw.authMux.ServeHTTP(w, r)
} else {
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
}
return false
}
return true
}

View File

@@ -0,0 +1,73 @@
package middleware
import (
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
RateLimiterOpts
Tracer
requestMap requestMap
mu sync.Mutex
}
RateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration `validate:"min=1s"`
}
)
var (
RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second,
}
)
// setup implements MiddlewareWithSetup.
func (rl *rateLimiter) setup() {
rl.RateLimiterOpts = rateLimiterOptsDefault
rl.requestMap = make(requestMap, 0)
}
// before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return false
}
rl.mu.Lock()
limiter, ok := rl.requestMap[host]
if !ok {
limiter = rl.newLimiter()
rl.requestMap[host] = limiter
}
rl.mu.Unlock()
if limiter.Allow() {
return true
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
}

View File

@@ -0,0 +1,27 @@
package middleware
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRateLimit(t *testing.T) {
opts := OptionsRaw{
"average": "10",
"burst": "10",
"period": "1s",
}
rl, err := RateLimiter.New(opts)
ExpectNoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
}

View File

@@ -0,0 +1,116 @@
package middleware
import (
"net"
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type (
realIP struct {
RealIPOpts
Tracer
}
RealIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"required,min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
the last address sent in the request header field defined by the Header field.
If recursive search is enabled,
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
)
var (
RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
}
)
// setup implements MiddlewareWithSetup.
func (ri *realIP) setup() {
ri.RealIPOpts = realIPOptsDefault
}
// before implements RequestModifier.
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
ri.setRealIP(r)
return true
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From {
if CIDR.Contains(ip) {
return true
}
}
// not in any CIDR
return false
}
func (ri *realIP) setRealIP(req *http.Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIPStr = req.RemoteAddr
}
clientIP := net.ParseIP(clientIPStr)
isTrusted := false
for _, CIDR := range ri.From {
if CIDR.Contains(clientIP) {
isTrusted = true
break
}
}
if !isTrusted {
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
realIPs := req.Header.Values(ri.Header)
lastNonTrustedIP := ""
if len(realIPs) == 0 {
// try non-canonical key
realIPs = req.Header[ri.Header]
}
if len(realIPs) == 0 {
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
if !ri.Recursive {
lastNonTrustedIP = realIPs[len(realIPs)-1]
} else {
for _, r := range realIPs {
if !ri.isInCIDRList(net.ParseIP(r)) {
lastNonTrustedIP = r
}
}
}
if lastNonTrustedIP == "" {
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View File

@@ -0,0 +1,77 @@
package middleware
import (
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": httpheaders.HeaderXRealIP,
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
},
"recursive": true,
}
optExpected := &RealIPOpts{
Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),
},
{
IP: net.ParseIP("192.168.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
{
IP: net.ParseIP("172.16.0.0"),
Mask: net.IPv4Mask(255, 240, 0, 0),
},
},
Recursive: true,
}
ri, err := RealIP.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
}
}
func TestSetRealIP(t *testing.T) {
const (
testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{
"header": testHeader,
"from": []string{"0.0.0.0/0"},
}
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := RealIP.New(opts)
ExpectNoError(t, err)
mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err)
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
}

View File

@@ -0,0 +1,51 @@
package middleware
import (
"net"
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
)
type redirectHTTP struct {
Bypass struct {
UserAgents []string
}
}
var RedirectHTTP = NewMiddleware[redirectHTTP]()
// before implements RequestModifier.
func (m *redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if r.TLS != nil {
return true
}
if len(m.Bypass.UserAgents) > 0 {
ua := r.UserAgent()
for _, uaBypass := range m.Bypass.UserAgents {
if strings.Contains(ua, uaBypass) {
return true
}
}
}
r.URL.Scheme = "https"
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
if common.ProxyHTTPSPort != "443" {
r.URL.Host = host + ":" + common.ProxyHTTPSPort
} else {
r.URL.Host = host
}
http.Redirect(w, r, r.URL.String(), http.StatusPermanentRedirect)
logging.Debug().Str("url", r.URL.String()).Str("user_agent", r.UserAgent()).Msg("redirect to https")
return false
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("http://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("https://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
)
// internal use only.
type setUpstreamHeaders struct {
Name, Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"name": rp.TargetName,
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true
}

View File

@@ -0,0 +1,23 @@
deny:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.1.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
accept:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.0.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
- 127.0.0.1

View File

@@ -0,0 +1,25 @@
theGreatPretender:
- use: HideXForwarded
- use: ModifyRequest
setHeaders:
X-Real-IP: 6.6.6.6
- use: ModifyResponse
hideHeaders:
- X-Test3
- X-Test4
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
header: X-Real-IP
from:
- "127.0.0.0/8"
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
testFakeRealIP:
- use: ModifyRequest
setHeaders:
CF-Connecting-IP: 127.0.0.1
- use: CloudflareRealIP

View File

@@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View File

@@ -0,0 +1,176 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
func init() {
if !common.IsTest {
return
}
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestRecorder struct {
args *testArgs
parent http.RoundTripper
headers http.Header
remoteAddr string
}
func newRequestRecorder(args *testArgs) *requestRecorder {
return &requestRecorder{args: args}
}
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
rt.headers = req.Header
rt.remoteAddr = req.RemoteAddr
if rt.parent != nil {
resp, err = rt.parent.RoundTrip(req)
} else {
resp = &http.Response{
Status: http.StatusText(rt.args.respStatus),
StatusCode: rt.args.respStatus,
Header: testHeaders,
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
ContentLength: int64(len(rt.args.respBody)),
Request: req,
TLS: req.TLS,
}
}
if err == nil {
for k, v := range rt.args.respHeaders {
resp.Header[k] = v
}
}
return resp, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
RemoteAddr string
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
upstreamURL *types.URL
realRoundTrip bool
reqURL *types.URL
reqMethod string
headers http.Header
body []byte
respHeaders http.Header
respBody []byte
respStatus int
}
func (args *testArgs) setDefaults() {
if args.reqURL == nil {
args.reqURL = Must(types.ParseURL("https://example.com"))
}
if args.reqMethod == "" {
args.reqMethod = http.MethodGet
}
if args.upstreamURL == nil {
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.respHeaders == nil {
args.respHeaders = http.Header{}
}
if args.respBody == nil {
args.respBody = []byte("OK")
}
if args.respStatus == 0 {
args.respStatus = http.StatusOK
}
}
func (args *testArgs) bodyReader() io.Reader {
if args.body != nil {
return bytes.NewReader(args.body)
}
return nil
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
return newMiddlewaresTest([]*Middleware{mid}, args)
}
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
for k, v := range args.headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
rr := newRequestRecorder(args)
if args.realRoundTrip {
rr.parent = http.DefaultTransport
}
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
patchReverseProxy(rp, middlewares)
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, gperr.Wrap(err)
}
return &TestResult{
RequestHeaders: rr.headers,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
RemoteAddr: rr.remoteAddr,
Data: data,
}, nil
}

View File

@@ -0,0 +1,87 @@
package middleware
import (
"net/http"
"sync"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
Traces []*Trace
)
var (
traces = make(Traces, 0)
tracesMu sync.Mutex
)
const MaxTraceNum = 100
func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr
}
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
if tr == nil {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}
func (tr *Trace) With(what string, additional any) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional[what] = additional
return tr
}
func (tr *Trace) WithError(err error) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional["error"] = err.Error()
return tr
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()
if len(traces) > MaxTraceNum {
traces = traces[1:]
}
traces = append(traces, t)
return t
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
enabled bool
}
func _() {
var _ MiddlewareWithTracer = &Tracer{}
}
func (t *Tracer) enableTrace() {
t.enabled = true
}
func (t *Tracer) getTracer() *Tracer {
return t
}
func (t *Tracer) SetParent(parent *Tracer) {
if parent == nil {
return
}
t.name = parent.name + "." + t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.name,
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}

View File

@@ -0,0 +1,175 @@
package middleware
import (
"net"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
reqVarGetter func(*http.Request) string
respVarGetter func(*http.Response) string
)
var (
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
const (
VarRequestMethod = "$req_method"
VarRequestScheme = "$req_scheme"
VarRequestHost = "$req_host"
VarRequestPort = "$req_port"
VarRequestPath = "$req_path"
VarRequestAddr = "$req_addr"
VarRequestQuery = "$req_query"
VarRequestURL = "$req_url"
VarRequestURI = "$req_uri"
VarRequestContentType = "$req_content_type"
VarRequestContentLen = "$req_content_length"
VarRemoteHost = "$remote_host"
VarRemotePort = "$remote_port"
VarRemoteAddr = "$remote_addr"
VarUpstreamName = "$upstream_name"
VarUpstreamScheme = "$upstream_scheme"
VarUpstreamHost = "$upstream_host"
VarUpstreamPort = "$upstream_port"
VarUpstreamAddr = "$upstream_addr"
VarUpstreamURL = "$upstream_url"
VarRespContentType = "$resp_content_type"
VarRespContentLen = "$resp_content_length"
VarRespStatusCode = "$status_code"
)
var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return ""
},
VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientPort
}
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
}
return upHost
},
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" {
return ""
}
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort
}
return upScheme + "://" + upAddr
},
}
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *http.Request, resp *http.Response, s string) string {
if req != nil {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
name := match[5 : len(match)-1]
for k, v := range req.URL.Query() {
if strings.EqualFold(k, name) {
return v[0]
}
}
return ""
})
// Replace request headers
s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
return req.Header.Get(header)
})
}
if resp != nil {
// Replace response headers
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
return resp.Header.Get(header)
})
}
// Replace static variables
if req != nil {
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := staticReqVarSubsMap[match]; ok {
return fn(req)
}
return match
})
}
if resp != nil {
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := staticRespVarSubsMap[match]; ok {
return fn(resp)
}
return match
})
}
return s
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"net"
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
setXForwarded struct{}
hideXForwarded struct{}
)
var (
SetXForwarded = NewMiddleware[setXForwarded]()
HideXForwarded = NewMiddleware[hideXForwarded]()
)
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
}
return true
}
// before implements RequestModifier.
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
toDelete := make([]string, 0, len(r.Header))
for k := range r.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
toDelete = append(toDelete, k)
}
}
for _, k := range toDelete {
r.Header.Del(k)
}
return true
}