mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 08:48:32 +02:00
refactor and organize code
This commit is contained in:
82
internal/net/gphttp/middleware/cidr_whitelist.go
Normal file
82
internal/net/gphttp/middleware/cidr_whitelist.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
cidrWhitelist struct {
|
||||
CIDRWhitelistOpts
|
||||
Tracer
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
CIDRWhitelistOpts struct {
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
|
||||
Message string
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
|
||||
cidrWhitelistDefaults = CIDRWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
||||
statusCode := fl.Field().Int()
|
||||
return gphttp.IsStatusCodeValid(int(statusCode))
|
||||
})
|
||||
}
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (wl *cidrWhitelist) setup() {
|
||||
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
|
||||
wl.cachedAddr = F.NewMapOf[string, bool]()
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return wl.checkIP(w, r)
|
||||
}
|
||||
|
||||
// checkIP checks if the IP address is allowed.
|
||||
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||
var allow, ok bool
|
||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
http.Error(w, wl.Message, wl.StatusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
91
internal/net/gphttp/middleware/cidr_whitelist_test.go
Normal file
91
internal/net/gphttp/middleware/cidr_whitelist_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/cidr_whitelist_test.yml
|
||||
var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||
const testMessage = "test-message"
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
_, err = CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
"status": 403,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
_, err = CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
"status_code": 403,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
t.Run("missing allow", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid cidr", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/123"},
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectErrorT[*net.ParseError](t, err)
|
||||
})
|
||||
t.Run("invalid status code", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"status_code": 600,
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
errs := gperr.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
panic("bug occurred")
|
||||
}
|
||||
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
|
||||
ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
127
internal/net/gphttp/middleware/cloudflare_real_ip.go
Normal file
127
internal/net/gphttp/middleware/cloudflare_real_ip.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type cloudflareRealIP struct {
|
||||
realIP realIP
|
||||
Recursive bool
|
||||
}
|
||||
|
||||
const (
|
||||
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
|
||||
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
|
||||
cfCIDRsUpdateInterval = time.Hour
|
||||
cfCIDRsUpdateRetryInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
cfCIDRsLastUpdate time.Time
|
||||
cfCIDRsMu sync.Mutex
|
||||
)
|
||||
|
||||
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (cri *cloudflareRealIP) setup() {
|
||||
cri.realIP.RealIPOpts = RealIPOpts{
|
||||
Header: "Cf-Connecting-Ip",
|
||||
Recursive: cri.Recursive,
|
||||
}
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.realIP.From = cidrs
|
||||
}
|
||||
return cri.realIP.before(w, r)
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) enableTrace() {
|
||||
cri.realIP.enableTrace()
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||
return cri.realIP.getTracer()
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
cfCIDRsMu.Lock()
|
||||
defer cfCIDRsMu.Unlock()
|
||||
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
if common.IsTest {
|
||||
cfCIDRs = []*types.CIDR{
|
||||
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
|
||||
}
|
||||
} else {
|
||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||
err := errors.Join(
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
if len(cfCIDRs) == 0 {
|
||||
logging.Warn().Msg("cloudflare CIDR range is empty")
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
logging.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range strutils.SplitLine(string(body)) {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
_, cidr, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
}
|
||||
|
||||
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
80
internal/net/gphttp/middleware/custom_error_page.go
Normal file
80
internal/net/gphttp/middleware/custom_error_page.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||
)
|
||||
|
||||
type customErrorPage struct{}
|
||||
|
||||
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||
|
||||
const StaticFilePathPrefix = "/$gperrorpage/"
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
return !ServeStaticErrorPageFile(w, r)
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
} else {
|
||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
|
||||
path := r.URL.Path
|
||||
if path != "" && path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
if strings.HasPrefix(path, StaticFilePathPrefix) {
|
||||
filename := path[len(StaticFilePathPrefix):]
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
logging.Error().Msg("unable to load resource " + filename)
|
||||
return false
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||
default:
|
||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
logging.Err(err).Msg("unable to write resource " + filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
96
internal/net/gphttp/middleware/errorpage/error_page.go
Normal file
96
internal/net/gphttp/middleware/errorpage/error_page.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package errorpage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
const errPagesBasePath = common.ErrorPagesBasePath
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
dirWatcher W.Watcher
|
||||
fileContentMap = F.NewMapOf[string, []byte]()
|
||||
)
|
||||
|
||||
func setup() {
|
||||
t := task.RootTask("error_page", false)
|
||||
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
|
||||
loadContent()
|
||||
go watchDir()
|
||||
}
|
||||
|
||||
func GetStaticFile(filename string) ([]byte, bool) {
|
||||
setupOnce.Do(setup)
|
||||
return fileContentMap.Load(filename)
|
||||
}
|
||||
|
||||
// try <statusCode>.html -> 404.html -> not ok.
|
||||
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
|
||||
if !ok && statusCode != 404 {
|
||||
return fileContentMap.Load("404.html")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func loadContent() {
|
||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list error page resources")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
if fileContentMap.Has(file) {
|
||||
continue
|
||||
}
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
continue
|
||||
}
|
||||
file = path.Base(file)
|
||||
logging.Info().Msgf("error page resource %s loaded", file)
|
||||
fileContentMap.Store(file, content)
|
||||
}
|
||||
}
|
||||
|
||||
func watchDir() {
|
||||
eventCh, errCh := dirWatcher.Events(task.RootContext())
|
||||
for {
|
||||
select {
|
||||
case <-task.RootContextCanceled():
|
||||
return
|
||||
case event, ok := <-eventCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
filename := event.ActorName
|
||||
switch event.Action {
|
||||
case events.ActionFileWritten:
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
gperr.LogError("error watching error page directory", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package metricslogger
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/metrics"
|
||||
)
|
||||
|
||||
type MetricsLogger struct {
|
||||
ServiceName string `json:"service_name"`
|
||||
}
|
||||
|
||||
func NewMetricsLogger(serviceName string) *MetricsLogger {
|
||||
return &MetricsLogger{serviceName}
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) GetHandler(next http.Handler) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
m.ServeHTTP(rw, req, next.ServeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
visitorIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
visitorIP = req.RemoteAddr
|
||||
}
|
||||
|
||||
// req.RemoteAddr had been modified by middleware (if any)
|
||||
lbls := &metrics.HTTPRouteMetricLabels{
|
||||
Service: m.ServiceName,
|
||||
Method: req.Method,
|
||||
Host: req.Host,
|
||||
Visitor: visitorIP,
|
||||
Path: req.URL.Path,
|
||||
}
|
||||
|
||||
next.ServeHTTP(newHTTPMetricLogger(rw, lbls), req)
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) ResetMetrics() {
|
||||
metrics.GetRouteMetrics().UnregisterService(m.ServiceName)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package metricslogger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/metrics"
|
||||
)
|
||||
|
||||
type httpMetricLogger struct {
|
||||
http.ResponseWriter
|
||||
timestamp time.Time
|
||||
labels *metrics.HTTPRouteMetricLabels
|
||||
}
|
||||
|
||||
// WriteHeader implements http.ResponseWriter.
|
||||
func (l *httpMetricLogger) WriteHeader(status int) {
|
||||
l.ResponseWriter.WriteHeader(status)
|
||||
duration := time.Since(l.timestamp)
|
||||
go func() {
|
||||
m := metrics.GetRouteMetrics()
|
||||
m.HTTPReqTotal.Inc()
|
||||
m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds()))
|
||||
|
||||
// ignore 1xx
|
||||
switch {
|
||||
case status >= 500:
|
||||
m.HTTP5xx.With(l.labels).Inc()
|
||||
case status >= 400:
|
||||
m.HTTP4xx.With(l.labels).Inc()
|
||||
case status >= 200:
|
||||
m.HTTP2xx3xx.With(l.labels).Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *httpMetricLogger) Unwrap() http.ResponseWriter {
|
||||
return l.ResponseWriter
|
||||
}
|
||||
|
||||
func newHTTPMetricLogger(w http.ResponseWriter, labels *metrics.HTTPRouteMetricLabels) *httpMetricLogger {
|
||||
return &httpMetricLogger{
|
||||
ResponseWriter: w,
|
||||
timestamp: time.Now(),
|
||||
labels: labels,
|
||||
}
|
||||
}
|
||||
237
internal/net/gphttp/middleware/middleware.go
Normal file
237
internal/net/gphttp/middleware/middleware.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = gperr.Error
|
||||
|
||||
ReverseProxy = reverseproxy.ReverseProxy
|
||||
ProxyRequest = reverseproxy.ProxyRequest
|
||||
|
||||
ImplNewFunc = func() any
|
||||
OptionsRaw = map[string]any
|
||||
|
||||
Middleware struct {
|
||||
name string
|
||||
construct ImplNewFunc
|
||||
impl any
|
||||
// priority is only applied for ReverseProxy.
|
||||
//
|
||||
// Middleware compose follows the order of the slice
|
||||
//
|
||||
// Default is 10, 0 is the highest
|
||||
priority int
|
||||
}
|
||||
ByPriority []*Middleware
|
||||
|
||||
RequestModifier interface {
|
||||
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
}
|
||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||
MiddlewareWithSetup interface{ setup() }
|
||||
MiddlewareFinalizer interface{ finalize() }
|
||||
MiddlewareFinalizerWithError interface {
|
||||
finalize() error
|
||||
}
|
||||
MiddlewareWithTracer interface {
|
||||
enableTrace()
|
||||
getTracer() *Tracer
|
||||
}
|
||||
)
|
||||
|
||||
const DefaultPriority = 10
|
||||
|
||||
func (m ByPriority) Len() int { return len(m) }
|
||||
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
|
||||
|
||||
func NewMiddleware[ImplType any]() *Middleware {
|
||||
// type check
|
||||
t := any(new(ImplType))
|
||||
switch t.(type) {
|
||||
case RequestModifier:
|
||||
case ResponseModifier:
|
||||
default:
|
||||
panic("must implement RequestModifier or ResponseModifier")
|
||||
}
|
||||
_, hasFinializer := t.(MiddlewareFinalizer)
|
||||
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
||||
if hasFinializer && hasFinializerWithError {
|
||||
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
||||
}
|
||||
return &Middleware{
|
||||
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||
construct: func() any { return new(ImplType) },
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.enableTrace()
|
||||
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getTracer() *Tracer {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
return tracer.getTracer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) setParent(parent *Middleware) {
|
||||
if tracer := m.getTracer(); tracer != nil {
|
||||
tracer.SetParent(parent.getTracer())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) setup() {
|
||||
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
||||
setup.setup()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
priority, ok := optsRaw["priority"].(int)
|
||||
if ok {
|
||||
m.priority = priority
|
||||
// remove priority for deserialization, restore later
|
||||
delete(optsRaw, "priority")
|
||||
defer func() {
|
||||
optsRaw["priority"] = priority
|
||||
}()
|
||||
} else {
|
||||
m.priority = DefaultPriority
|
||||
}
|
||||
return utils.Deserialize(optsRaw, m.impl)
|
||||
}
|
||||
|
||||
func (m *Middleware) finalize() error {
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||
finalizer.finalize()
|
||||
return nil
|
||||
}
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
||||
return finalizer.finalize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||
if m.construct == nil { // likely a middleware from compose
|
||||
if len(optsRaw) != 0 {
|
||||
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
mid := &Middleware{name: m.name, impl: m.construct()}
|
||||
mid.setup()
|
||||
if err := mid.apply(optsRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := mid.finalize(); err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
return mid, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) String() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||
return json.MarshalIndent(map[string]any{
|
||||
"name": m.name,
|
||||
"options": m.impl,
|
||||
"priority": m.priority,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
return exec.modifyResponse(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
return exec.modifyResponse(resp)
|
||||
})
|
||||
}
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
patchReverseProxy(rp, middlewares)
|
||||
return
|
||||
}
|
||||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
sort.Sort(ByPriority(middlewares))
|
||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||
|
||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||
|
||||
if before, ok := mid.impl.(RequestModifier); ok {
|
||||
next := rp.HandlerFunc
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
if proceed := before.before(w, r); proceed {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mr, ok := mid.impl.(ResponseModifier); ok {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
return err
|
||||
}
|
||||
return ori(res)
|
||||
}
|
||||
} else {
|
||||
rp.ModifyResponse = mr.modifyResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
107
internal/net/gphttp/middleware/middleware_builder.go
Normal file
107
internal/net/gphttp/middleware/middleware_builder.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
middlewares := make(map[string]*Middleware)
|
||||
for name, defs := range rawMap {
|
||||
chain, err := BuildMiddlewareFromChainRaw(name, defs)
|
||||
if err != nil {
|
||||
eb.Add(err.Subject(source))
|
||||
} else {
|
||||
middlewares[name+"@file"] = chain
|
||||
}
|
||||
}
|
||||
return middlewares
|
||||
}
|
||||
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
errs := gperr.NewBuilder("middlewares compile error")
|
||||
invalidOpts := gperr.NewBuilder("options compile error")
|
||||
|
||||
for name, opts := range middlewaresMap {
|
||||
m, err := Get(name)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err = m.New(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
}
|
||||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
if invalidOpts.HasError() {
|
||||
errs.Add(invalidOpts.Error())
|
||||
}
|
||||
sort.Sort(ByPriority(middlewares))
|
||||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
|
||||
compiled, err := compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMiddlewareChain(name, compiled), nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
|
||||
chainErr := gperr.NewBuilder("")
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(ErrMissingMiddlewareUse.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, err := Get(baseName)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.New(def)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
chain = append(chain, m)
|
||||
}
|
||||
if chainErr.HasError() {
|
||||
return nil, chainErr.Error()
|
||||
}
|
||||
return NewMiddlewareChain(name, chain), nil
|
||||
}
|
||||
22
internal/net/gphttp/middleware/middleware_builder_test.go
Normal file
22
internal/net/gphttp/middleware/middleware_builder_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/middleware_compose.yml
|
||||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
errs := gperr.NewBuilder("")
|
||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
Must(json.MarshalIndent(middlewares, "", " "))
|
||||
// t.Log(string(data))
|
||||
// TODO: test
|
||||
}
|
||||
61
internal/net/gphttp/middleware/middleware_chain.go
Normal file
61
internal/net/gphttp/middleware/middleware_chain.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type middlewareChain struct {
|
||||
befores []RequestModifier
|
||||
modResps []ResponseModifier
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
||||
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
|
||||
m := &Middleware{name: name, impl: chainMid}
|
||||
|
||||
for _, comp := range chain {
|
||||
if before, ok := comp.impl.(RequestModifier); ok {
|
||||
chainMid.befores = append(chainMid.befores, before)
|
||||
}
|
||||
if mr, ok := comp.impl.(ResponseModifier); ok {
|
||||
chainMid.modResps = append(chainMid.modResps, mr)
|
||||
}
|
||||
comp.setParent(m)
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
for _, child := range chain {
|
||||
child.enableTrace()
|
||||
}
|
||||
m.enableTrace()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||
for _, b := range m.befores {
|
||||
if proceedNext = b.before(w, r); !proceedNext {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := gperr.NewBuilder("modify response errors")
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
37
internal/net/gphttp/middleware/middleware_test.go
Normal file
37
internal/net/gphttp/middleware/middleware_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
type testPriority struct {
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
var test = NewMiddleware[testPriority]()
|
||||
|
||||
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMiddlewarePriority(t *testing.T) {
|
||||
priorities := []int{4, 7, 9, 0}
|
||||
chain := make([]*Middleware, len(priorities))
|
||||
for i, p := range priorities {
|
||||
mid, err := test.New(OptionsRaw{
|
||||
"priority": p,
|
||||
"value": i,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
chain[i] = mid
|
||||
}
|
||||
res, err := newMiddlewaresTest(chain, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
|
||||
}
|
||||
104
internal/net/gphttp/middleware/middlewares.go
Normal file
104
internal/net/gphttp/middleware/middlewares.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
// snakes and cases will be stripped on `Get`
|
||||
// so keys are lowercase without snake.
|
||||
var allMiddlewares = map[string]*Middleware{
|
||||
"redirecthttp": RedirectHTTP,
|
||||
|
||||
"oidc": OIDC,
|
||||
|
||||
"request": ModifyRequest,
|
||||
"modifyrequest": ModifyRequest,
|
||||
"response": ModifyResponse,
|
||||
"modifyresponse": ModifyResponse,
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
|
||||
"errorpage": CustomErrorPage,
|
||||
"customerrorpage": CustomErrorPage,
|
||||
|
||||
"realip": RealIP,
|
||||
"cloudflarerealip": CloudflareRealIP,
|
||||
|
||||
"cidrwhitelist": CIDRWhiteList,
|
||||
"ratelimit": RateLimiter,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = gperr.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
errs := gperr.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
name = strutils.ToLowerNoSnake(name)
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
errs.Add(ErrDuplicatedMiddleware.Subject(name))
|
||||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
}
|
||||
// build again to resolve cross references
|
||||
for _, defFile := range middlewareDefs {
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
name = strutils.ToLowerNoSnake(name)
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
// already loaded above
|
||||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
}
|
||||
if errs.HasError() {
|
||||
gperr.LogError(errs.About(), errs.Error())
|
||||
}
|
||||
}
|
||||
91
internal/net/gphttp/middleware/modify_request.go
Normal file
91
internal/net/gphttp/middleware/modify_request.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyRequest struct {
|
||||
ModifyRequestOpts
|
||||
Tracer
|
||||
}
|
||||
// order: add_prefix -> set_headers -> add_headers -> hide_headers
|
||||
ModifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
AddPrefix string
|
||||
|
||||
needVarSubstitution bool
|
||||
}
|
||||
)
|
||||
|
||||
var ModifyRequest = NewMiddleware[modifyRequest]()
|
||||
|
||||
// finalize implements MiddlewareFinalizer.
|
||||
func (mr *ModifyRequestOpts) finalize() {
|
||||
mr.checkVarSubstitution()
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
mr.AddTraceRequest("before modify request", r)
|
||||
|
||||
mr.addPrefix(r, nil, r.URL.Path)
|
||||
mr.modifyHeaders(r, nil, r.Header)
|
||||
mr.AddTraceRequest("after modify request", r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mr *ModifyRequestOpts) checkVarSubstitution() {
|
||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||
for _, v := range m {
|
||||
if strings.ContainsRune(v, '$') {
|
||||
mr.needVarSubstitution = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
|
||||
if !mr.needVarSubstitution {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = v
|
||||
}()
|
||||
}
|
||||
headers[k] = []string{v}
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
headers[k] = append(headers[k], v)
|
||||
}
|
||||
} else {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = varReplace(req, resp, v)
|
||||
}()
|
||||
}
|
||||
headers[k] = []string{varReplace(req, resp, v)}
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
headers[k] = append(headers[k], varReplace(req, resp, v))
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range mr.HideHeaders {
|
||||
delete(headers, k)
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) addPrefix(r *http.Request, _ *http.Response, path string) {
|
||||
if len(mr.AddPrefix) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.URL.Path = filepath.Join(mr.AddPrefix, path)
|
||||
}
|
||||
145
internal/net/gphttp/middleware/modify_request_test.go
Normal file
145
internal/net/gphttp/middleware/modify_request_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestModifyRequest(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": VarUpstreamAddr,
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||
"X-Test-Remote-Host": VarRemoteHost,
|
||||
"X-Test-Remote-Port": VarRemotePort,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
|
||||
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
|
||||
|
||||
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
|
||||
t.Run("add_prefix", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectedPath string
|
||||
upstreamURL string
|
||||
addPrefix string
|
||||
}{
|
||||
{
|
||||
name: "no prefix",
|
||||
path: "/foo",
|
||||
expectedPath: "/foo",
|
||||
upstreamURL: "http://test.example.com",
|
||||
},
|
||||
{
|
||||
name: "slash only",
|
||||
path: "/",
|
||||
expectedPath: "/",
|
||||
upstreamURL: "http://test.example.com",
|
||||
addPrefix: "/", // should not change anything
|
||||
},
|
||||
{
|
||||
name: "some prefix",
|
||||
path: "/test",
|
||||
expectedPath: "/foo/test",
|
||||
upstreamURL: "http://test.example.com",
|
||||
addPrefix: "/foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app" + tt.path)
|
||||
upstreamURL := types.MustParseURL(tt.upstreamURL)
|
||||
|
||||
opts["add_prefix"] = tt.addPrefix
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
20
internal/net/gphttp/middleware/modify_response.go
Normal file
20
internal/net/gphttp/middleware/modify_response.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type modifyResponse struct {
|
||||
ModifyRequestOpts
|
||||
Tracer
|
||||
}
|
||||
|
||||
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
mr.AddTraceResponse("before modify response", resp)
|
||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||
mr.AddTraceResponse("after modify response", resp)
|
||||
return nil
|
||||
}
|
||||
108
internal/net/gphttp/middleware/modify_response_test.go
Normal file
108
internal/net/gphttp/middleware/modify_response_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestModifyResponse(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"X-Test-Resp-Status": VarRespStatusCode,
|
||||
"X-Test-Resp-Content-Type": VarRespContentType,
|
||||
"X-Test-Resp-Content-Length": VarRespContentLen,
|
||||
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
|
||||
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||
"X-Test-Remote-Host": VarRemoteHost,
|
||||
"X-Test-Remote-Port": VarRemotePort,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("response_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respBody: bytes.Repeat([]byte("a"), 50),
|
||||
respStatus: http.StatusOK,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
|
||||
|
||||
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
}
|
||||
91
internal/net/gphttp/middleware/oidc.go
Normal file
91
internal/net/gphttp/middleware/oidc.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
AllowedUsers []string `json:"allowed_users"`
|
||||
AllowedGroups []string `json:"allowed_groups"`
|
||||
|
||||
auth auth.Provider
|
||||
authMux *http.ServeMux
|
||||
|
||||
isInitialized int32
|
||||
initMu sync.Mutex
|
||||
}
|
||||
|
||||
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return gperr.New("OIDC not enabled but OIDC middleware is used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) init() error {
|
||||
if atomic.LoadInt32(&amw.isInitialized) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return amw.initSlow()
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) initSlow() error {
|
||||
amw.initMu.Lock()
|
||||
if amw.isInitialized == 1 {
|
||||
amw.initMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
amw.isInitialized = 1
|
||||
amw.initMu.Unlock()
|
||||
}()
|
||||
|
||||
authProvider, err := auth.NewOIDCProviderFromEnv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authProvider.SetIsMiddleware(true)
|
||||
if len(amw.AllowedUsers) > 0 {
|
||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||
}
|
||||
if len(amw.AllowedGroups) > 0 {
|
||||
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||
}
|
||||
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||
amw.auth = authProvider
|
||||
return nil
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if err := amw.init(); err != nil {
|
||||
// no need to log here, main OIDC may already failed and logged
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := amw.auth.CheckToken(r); err != nil {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
return false
|
||||
}
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.auth.LogoutCallbackHandler(w, r)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
73
internal/net/gphttp/middleware/rate_limit.go
Normal file
73
internal/net/gphttp/middleware/rate_limit.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
requestMap = map[string]*rate.Limiter
|
||||
rateLimiter struct {
|
||||
RateLimiterOpts
|
||||
Tracer
|
||||
|
||||
requestMap requestMap
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
RateLimiterOpts struct {
|
||||
Average int `validate:"min=1,required"`
|
||||
Burst int `validate:"min=1,required"`
|
||||
Period time.Duration `validate:"min=1s"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RateLimiter = NewMiddleware[rateLimiter]()
|
||||
rateLimiterOptsDefault = RateLimiterOpts{
|
||||
Period: time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (rl *rateLimiter) setup() {
|
||||
rl.RateLimiterOpts = rateLimiterOptsDefault
|
||||
rl.requestMap = make(requestMap, 0)
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return rl.limit(w, r)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) newLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
|
||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
limiter, ok := rl.requestMap[host]
|
||||
if !ok {
|
||||
limiter = rl.newLimiter()
|
||||
rl.requestMap[host] = limiter
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
if limiter.Allow() {
|
||||
return true
|
||||
}
|
||||
|
||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||
return false
|
||||
}
|
||||
27
internal/net/gphttp/middleware/rate_limit_test.go
Normal file
27
internal/net/gphttp/middleware/rate_limit_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"average": "10",
|
||||
"burst": "10",
|
||||
"period": "1s",
|
||||
}
|
||||
|
||||
rl, err := RateLimiter.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(rl, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
result, err := newMiddlewareTest(rl, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
|
||||
}
|
||||
116
internal/net/gphttp/middleware/real_ip.go
Normal file
116
internal/net/gphttp/middleware/real_ip.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
||||
type (
|
||||
realIP struct {
|
||||
RealIPOpts
|
||||
Tracer
|
||||
}
|
||||
RealIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `validate:"required,min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last address sent in the request header field defined by the Header field.
|
||||
If recursive search is enabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RealIP = NewMiddleware[realIP]()
|
||||
realIPOptsDefault = RealIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*types.CIDR{},
|
||||
}
|
||||
)
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (ri *realIP) setup() {
|
||||
ri.RealIPOpts = realIPOptsDefault
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
ri.setRealIP(r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||
for _, CIDR := range ri.From {
|
||||
if CIDR.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// not in any CIDR
|
||||
return false
|
||||
}
|
||||
|
||||
func (ri *realIP) setRealIP(req *http.Request) {
|
||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
clientIPStr = req.RemoteAddr
|
||||
}
|
||||
|
||||
clientIP := net.ParseIP(clientIPStr)
|
||||
isTrusted := false
|
||||
|
||||
for _, CIDR := range ri.From {
|
||||
if CIDR.Contains(clientIP) {
|
||||
isTrusted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isTrusted {
|
||||
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
return
|
||||
}
|
||||
|
||||
realIPs := req.Header.Values(ri.Header)
|
||||
lastNonTrustedIP := ""
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
// try non-canonical key
|
||||
realIPs = req.Header[ri.Header]
|
||||
}
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
if !ri.Recursive {
|
||||
lastNonTrustedIP = realIPs[len(realIPs)-1]
|
||||
} else {
|
||||
for _, r := range realIPs {
|
||||
if !ri.isInCIDRList(net.ParseIP(r)) {
|
||||
lastNonTrustedIP = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastNonTrustedIP == "" {
|
||||
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
return
|
||||
}
|
||||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
|
||||
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
||||
77
internal/net/gphttp/middleware/real_ip_test.go
Normal file
77
internal/net/gphttp/middleware/real_ip_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIPOpts(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": httpheaders.HeaderXRealIP,
|
||||
"from": []string{
|
||||
"127.0.0.0/8",
|
||||
"192.168.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
},
|
||||
"recursive": true,
|
||||
}
|
||||
optExpected := &RealIPOpts{
|
||||
Header: httpheaders.HeaderXRealIP,
|
||||
From: []*types.CIDR{
|
||||
{
|
||||
IP: net.ParseIP("127.0.0.0"),
|
||||
Mask: net.IPv4Mask(255, 0, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("192.168.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("172.16.0.0"),
|
||||
Mask: net.IPv4Mask(255, 240, 0, 0),
|
||||
},
|
||||
},
|
||||
Recursive: true,
|
||||
}
|
||||
|
||||
ri, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
for i, CIDR := range ri.impl.(*realIP).From {
|
||||
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
const (
|
||||
testHeader = httpheaders.HeaderXRealIP
|
||||
testRealIP = "192.168.1.1"
|
||||
)
|
||||
opts := OptionsRaw{
|
||||
"header": testHeader,
|
||||
"from": []string{"0.0.0.0/0"},
|
||||
}
|
||||
optsMr := OptionsRaw{
|
||||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mr, err := ModifyRequest.New(optsMr)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err)
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
}
|
||||
29
internal/net/gphttp/middleware/redirect_http.go
Normal file
29
internal/net/gphttp/middleware/redirect_http.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type redirectHTTP struct{}
|
||||
|
||||
var RedirectHTTP = NewMiddleware[redirectHTTP]()
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
r.URL.Scheme = "https"
|
||||
host := r.Host
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i] // strip port number if present
|
||||
}
|
||||
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
||||
logging.Debug().Str("url", r.URL.String()).Msg("redirect to https")
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return true
|
||||
}
|
||||
27
internal/net/gphttp/middleware/redirect_http_test.go
Normal file
27
internal/net/gphttp/middleware/redirect_http_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPs(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("http://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
|
||||
}
|
||||
|
||||
func TestNoRedirect(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("https://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
37
internal/net/gphttp/middleware/set_upstream_headers.go
Normal file
37
internal/net/gphttp/middleware/set_upstream_headers.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
)
|
||||
|
||||
// internal use only.
|
||||
type setUpstreamHeaders struct {
|
||||
Name, Scheme, Host, Port string
|
||||
}
|
||||
|
||||
var suh = NewMiddleware[setUpstreamHeaders]()
|
||||
|
||||
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
|
||||
m, err := suh.New(OptionsRaw{
|
||||
"name": rp.TargetName,
|
||||
"scheme": rp.TargetURL.Scheme,
|
||||
"host": rp.TargetURL.Hostname(),
|
||||
"port": rp.TargetURL.Port(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
deny:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.1.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
accept:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.0.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
- 127.0.0.1
|
||||
@@ -0,0 +1,41 @@
|
||||
theGreatPretender:
|
||||
- use: HideXForwarded
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 6.6.6.6
|
||||
- use: ModifyResponse
|
||||
hideHeaders:
|
||||
- X-Test3
|
||||
- X-Test4
|
||||
|
||||
notAuthenticAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
addAuthCookiesToResponse:
|
||||
- session_id
|
||||
- user_id
|
||||
authResponseHeaders:
|
||||
- X-Auth-SessionID
|
||||
- X-Auth-UserID
|
||||
- use: CustomErrorPage
|
||||
|
||||
realIPAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- "127.0.0.0/8"
|
||||
- "192.168.0.0/16"
|
||||
- "172.16.0.0/12"
|
||||
recursive: true
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
|
||||
testFakeRealIP:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
CF-Connecting-IP: 127.0.0.1
|
||||
- use: CloudflareRealIP
|
||||
17
internal/net/gphttp/middleware/test_data/sample_headers.json
Normal file
17
internal/net/gphttp/middleware/test_data/sample_headers.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
|
||||
"Dnt": "1",
|
||||
"Host": "localhost",
|
||||
"Priority": "u=0, i",
|
||||
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
|
||||
"Sec-Ch-Ua-Mobile": "?0",
|
||||
"Sec-Ch-Ua-Platform": "\"Windows\"",
|
||||
"Sec-Fetch-Dest": "document",
|
||||
"Sec-Fetch-Mode": "navigate",
|
||||
"Sec-Fetch-Site": "none",
|
||||
"Sec-Fetch-User": "?1",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
|
||||
}
|
||||
176
internal/net/gphttp/middleware/test_utils.go
Normal file
176
internal/net/gphttp/middleware/test_utils.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/sample_headers.json
|
||||
var testHeadersRaw []byte
|
||||
var testHeaders http.Header
|
||||
|
||||
func init() {
|
||||
if !common.IsTest {
|
||||
return
|
||||
}
|
||||
tmp := map[string]string{}
|
||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
testHeaders = http.Header{}
|
||||
for k, v := range tmp {
|
||||
testHeaders.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
type requestRecorder struct {
|
||||
args *testArgs
|
||||
|
||||
parent http.RoundTripper
|
||||
headers http.Header
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func newRequestRecorder(args *testArgs) *requestRecorder {
|
||||
return &requestRecorder{args: args}
|
||||
}
|
||||
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
rt.headers = req.Header
|
||||
rt.remoteAddr = req.RemoteAddr
|
||||
if rt.parent != nil {
|
||||
resp, err = rt.parent.RoundTrip(req)
|
||||
} else {
|
||||
resp = &http.Response{
|
||||
Status: http.StatusText(rt.args.respStatus),
|
||||
StatusCode: rt.args.respStatus,
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
|
||||
ContentLength: int64(len(rt.args.respBody)),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
for k, v := range rt.args.respHeaders {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
ResponseStatus int
|
||||
RemoteAddr string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
upstreamURL *types.URL
|
||||
|
||||
realRoundTrip bool
|
||||
|
||||
reqURL *types.URL
|
||||
reqMethod string
|
||||
headers http.Header
|
||||
body []byte
|
||||
|
||||
respHeaders http.Header
|
||||
respBody []byte
|
||||
respStatus int
|
||||
}
|
||||
|
||||
func (args *testArgs) setDefaults() {
|
||||
if args.reqURL == nil {
|
||||
args.reqURL = Must(types.ParseURL("https://example.com"))
|
||||
}
|
||||
if args.reqMethod == "" {
|
||||
args.reqMethod = http.MethodGet
|
||||
}
|
||||
if args.upstreamURL == nil {
|
||||
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
}
|
||||
if args.respHeaders == nil {
|
||||
args.respHeaders = http.Header{}
|
||||
}
|
||||
if args.respBody == nil {
|
||||
args.respBody = []byte("OK")
|
||||
}
|
||||
if args.respStatus == 0 {
|
||||
args.respStatus = http.StatusOK
|
||||
}
|
||||
}
|
||||
|
||||
func (args *testArgs) bodyReader() io.Reader {
|
||||
if args.body != nil {
|
||||
return bytes.NewReader(args.body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
args.setDefaults()
|
||||
|
||||
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
|
||||
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||
}
|
||||
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
args.setDefaults()
|
||||
|
||||
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||
for k, v := range args.headers {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rr := newRequestRecorder(args)
|
||||
if args.realRoundTrip {
|
||||
rr.parent = http.DefaultTransport
|
||||
}
|
||||
|
||||
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
|
||||
patchReverseProxy(rp, middlewares)
|
||||
rp.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
|
||||
return &TestResult{
|
||||
RequestHeaders: rr.headers,
|
||||
ResponseHeaders: resp.Header,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
RemoteAddr: rr.remoteAddr,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
87
internal/net/gphttp/middleware/trace.go
Normal file
87
internal/net/gphttp/middleware/trace.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders map[string]string `json:"req_headers,omitempty"`
|
||||
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
||||
RespStatus int `json:"resp_status,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
Traces []*Trace
|
||||
)
|
||||
|
||||
var (
|
||||
traces = make(Traces, 0)
|
||||
tracesMu sync.Mutex
|
||||
)
|
||||
|
||||
const MaxTraceNum = 100
|
||||
|
||||
func GetAllTrace() []*Trace {
|
||||
return traces
|
||||
}
|
||||
|
||||
func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = req.RequestURI
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = resp.Request.RequestURI
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
|
||||
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
|
||||
tr.RespStatus = resp.StatusCode
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) With(what string, additional any) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional[what] = additional
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithError(err error) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional["error"] = err.Error()
|
||||
return tr
|
||||
}
|
||||
|
||||
func addTrace(t *Trace) *Trace {
|
||||
tracesMu.Lock()
|
||||
defer tracesMu.Unlock()
|
||||
if len(traces) > MaxTraceNum {
|
||||
traces = traces[1:]
|
||||
}
|
||||
traces = append(traces, t)
|
||||
return t
|
||||
}
|
||||
62
internal/net/gphttp/middleware/tracer.go
Normal file
62
internal/net/gphttp/middleware/tracer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
name string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func _() {
|
||||
var _ MiddlewareWithTracer = &Tracer{}
|
||||
}
|
||||
|
||||
func (t *Tracer) enableTrace() {
|
||||
t.enabled = true
|
||||
}
|
||||
|
||||
func (t *Tracer) getTracer() *Tracer {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tracer) SetParent(parent *Tracer) {
|
||||
if parent == nil {
|
||||
return
|
||||
}
|
||||
t.name = parent.name + "." + t.name
|
||||
}
|
||||
|
||||
func (t *Tracer) addTrace(msg string) *Trace {
|
||||
return addTrace(&Trace{
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: t.name,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithResponse(resp)
|
||||
}
|
||||
175
internal/net/gphttp/middleware/vars.go
Normal file
175
internal/net/gphttp/middleware/vars.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
reqVarGetter func(*http.Request) string
|
||||
respVarGetter func(*http.Response) string
|
||||
)
|
||||
|
||||
var (
|
||||
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
|
||||
reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
|
||||
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
|
||||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||
)
|
||||
|
||||
const (
|
||||
VarRequestMethod = "$req_method"
|
||||
VarRequestScheme = "$req_scheme"
|
||||
VarRequestHost = "$req_host"
|
||||
VarRequestPort = "$req_port"
|
||||
VarRequestPath = "$req_path"
|
||||
VarRequestAddr = "$req_addr"
|
||||
VarRequestQuery = "$req_query"
|
||||
VarRequestURL = "$req_url"
|
||||
VarRequestURI = "$req_uri"
|
||||
VarRequestContentType = "$req_content_type"
|
||||
VarRequestContentLen = "$req_content_length"
|
||||
VarRemoteHost = "$remote_host"
|
||||
VarRemotePort = "$remote_port"
|
||||
VarRemoteAddr = "$remote_addr"
|
||||
|
||||
VarUpstreamName = "$upstream_name"
|
||||
VarUpstreamScheme = "$upstream_scheme"
|
||||
VarUpstreamHost = "$upstream_host"
|
||||
VarUpstreamPort = "$upstream_port"
|
||||
VarUpstreamAddr = "$upstream_addr"
|
||||
VarUpstreamURL = "$upstream_url"
|
||||
|
||||
VarRespContentType = "$resp_content_type"
|
||||
VarRespContentLen = "$resp_content_length"
|
||||
VarRespStatusCode = "$status_code"
|
||||
)
|
||||
|
||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
VarRequestMethod: func(req *http.Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *http.Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
},
|
||||
VarRequestHost: func(req *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return req.Host
|
||||
}
|
||||
return reqHost
|
||||
},
|
||||
VarRequestPort: func(req *http.Request) string {
|
||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||
return reqPort
|
||||
},
|
||||
VarRequestAddr: func(req *http.Request) string { return req.Host },
|
||||
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteHost: func(req *http.Request) string {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientIP
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemotePort: func(req *http.Request) string {
|
||||
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientPort
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *http.Request) string {
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
return upHost + ":" + upPort
|
||||
}
|
||||
return upHost
|
||||
},
|
||||
VarUpstreamURL: func(req *http.Request) string {
|
||||
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
}
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
upAddr := upHost
|
||||
if upPort != "" {
|
||||
upAddr += ":" + upPort
|
||||
}
|
||||
return upScheme + "://" + upAddr
|
||||
},
|
||||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
}
|
||||
|
||||
func varReplace(req *http.Request, resp *http.Response, s string) string {
|
||||
if req != nil {
|
||||
// Replace query parameters
|
||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||
name := match[5 : len(match)-1]
|
||||
for k, v := range req.URL.Query() {
|
||||
if strings.EqualFold(k, name) {
|
||||
return v[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
// Replace request headers
|
||||
s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
|
||||
return req.Header.Get(header)
|
||||
})
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
// Replace response headers
|
||||
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
|
||||
return resp.Header.Get(header)
|
||||
})
|
||||
}
|
||||
|
||||
// Replace static variables
|
||||
if req != nil {
|
||||
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||
if fn, ok := staticReqVarSubsMap[match]; ok {
|
||||
return fn(req)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||
if fn, ok := staticRespVarSubsMap[match]; ok {
|
||||
return fn(resp)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
45
internal/net/gphttp/middleware/x_forwarded.go
Normal file
45
internal/net/gphttp/middleware/x_forwarded.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
setXForwarded struct{}
|
||||
hideXForwarded struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
SetXForwarded = NewMiddleware[setXForwarded]()
|
||||
HideXForwarded = NewMiddleware[hideXForwarded]()
|
||||
)
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Del(httpheaders.HeaderXForwardedFor)
|
||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
toDelete := make([]string, 0, len(r.Header))
|
||||
for k := range r.Header {
|
||||
if strings.HasPrefix(k, "X-Forwarded-") {
|
||||
toDelete = append(toDelete, k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range toDelete {
|
||||
r.Header.Del(k)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user