added real_ip and cloudflare_real_ip middlewares, fixed that some middlewares does not work properly

This commit is contained in:
yusing
2024-09-30 04:03:48 +08:00
parent ac3af49aa7
commit 860e914b90
27 changed files with 463 additions and 179 deletions

View File

@@ -0,0 +1,32 @@
package http
import (
"mime"
"net/http"
)
type ContentType string
func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type")
if ct == "" {
return ""
}
ct, _, err := mime.ParseMediaType(ct)
if err != nil {
return ""
}
return ContentType(ct)
}
func (ct ContentType) IsHTML() bool {
return ct == "text/html" || ct == "application/xhtml+xml"
}
func (ct ContentType) IsJSON() bool {
return ct == "application/json"
}
func (ct ContentType) IsPlainText() bool {
return ct == "text/plain"
}

View File

@@ -0,0 +1,42 @@
package http
import (
"net/http"
"slices"
)
func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h)
if reqUpType != "" {
h.Set("Connection", "Upgrade")
h.Set("Upgrade", reqUpType)
} else {
h.Del("Connection")
}
}
func CopyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func FilterHeaders(h http.Header, allowed []string) {
if allowed == nil {
return
}
for i := range allowed {
allowed[i] = http.CanonicalHeaderKey(allowed[i])
}
for key := range h {
if !slices.Contains(allowed, key) {
h.Del(key)
}
}
}

View File

@@ -0,0 +1,118 @@
package middleware
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
cfCIDRsUpdateInterval = time.Hour
cfCIDRsUpdateRetryInterval = 3 * time.Second
)
var (
cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
)
var CloudflareRealIP = &realIP{
m: &Middleware{
withOptions: NewCloudflareRealIP,
},
}
func NewCloudflareRealIP(_ OptionsRaw, _ *ReverseProxy) (*Middleware, E.NestedError) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
rewrite: func(r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
},
}
cri.realIPOpts = &realIPOpts{
Header: "CF-Connecting-IP",
Recursive: true,
}
return cri.m, nil
}
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
cfCIDRsMu.Lock()
defer cfCIDRsMu.Unlock()
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
if common.IsTest {
cfCIDRs = []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
} else {
cfCIDRs = make([]*net.IPNet, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
return nil
}
}
cfCIDRsLastUpdate = time.Now()
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
for _, line := range strings.Split(string(body), "\n") {
if line == "" {
continue
}
_, cidr, err := net.ParseCIDR(line)
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, cidr)
}
}
return nil
}

View File

@@ -0,0 +1,75 @@
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
"github.com/yusing/go-proxy/internal/common"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
)
var CustomErrorPage = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next.ServeHTTP(w, r)
}
},
modifyResponse: func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gpHTTP.GetContentType(resp.Header)
if !gpHTTP.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, common.StaticFilePathPrefix) {
filename := path[len(common.StaticFilePathPrefix):]
file, ok := error_page.GetStaticFile(filename)
if !ok {
errPageLogger.Errorf("unable to load resource %s", filename)
return false
} else {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
w.Write(file)
return true
}
}
return false
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View File

@@ -0,0 +1,248 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
)
type (
forwardAuth struct {
*forwardAuthOpts
m *Middleware
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
const (
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
)
var ForwardAuth = func() *forwardAuth {
fa := new(forwardAuth)
fa.m = new(Middleware)
fa.m.labelParserMap = D.ValueParserMap{
"trust_forward_header": D.BoolParser,
"auth_response_headers": D.YamlStringListParser,
"add_auth_cookies_to_response": D.YamlStringListParser,
}
fa.m.withOptions = NewForwardAuthfunc
return fa
}()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func NewForwardAuthfunc(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
tr, ok := rp.Transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = common.DefaultTransport.Clone()
}
faWithOpts := new(forwardAuth)
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
faWithOpts.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
faWithOpts.m = &Middleware{
impl: faWithOpts,
before: faWithOpts.forward,
}
err := Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
_, err = E.Check(url.Parse(faWithOpts.Address))
if err != nil {
return nil, E.Invalid("address", faWithOpts.Address)
}
return faWithOpts.m, nil
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
gpHTTP.RemoveHop(req.Header)
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
gpHTTP.CopyHeader(faReq.Header, req.Header)
gpHTTP.RemoveHop(faReq.Header)
gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
}
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
}
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
}
}

View File

@@ -0,0 +1,147 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.NestedError
ReverseProxy = gpHTTP.ReverseProxy
ProxyRequest = gpHTTP.ProxyRequest
Request = http.Request
Response = http.Response
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
OptionsRaw = map[string]any
Options any
Middleware struct {
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
}
)
var Deserialize = U.Deserialize
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
for name, opts := range middlewares {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
continue
}
m, err := m.WithOptionsClone(opts, rp)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
}
}
if invalidM.HasError() {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
}
}
if len(rewrites) > 0 {
origServeHTTP = rp.ServeHTTP
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
origServeHTTP(w, r)
}
}
if len(modResps) > 0 {
if rp.ModifyResponse != nil {
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
}
rp.ModifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware ModifyResponse")
for _, mr := range modResps {
b.AddE(mr(res))
}
return b.Build().Error()
}
}
return
}

View File

@@ -0,0 +1,47 @@
package middleware
import (
"fmt"
"strings"
D "github.com/yusing/go-proxy/internal/docker"
)
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
return
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"add_x_forwarded": AddXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
"real_ip": RealIP.m,
"cloudflare_real_ip": CloudflareRealIP.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
D.RegisterNamespace(name, m.labelParserMap)
}
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
}
}
}

View File

@@ -0,0 +1,57 @@
package middleware
import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyRequest struct {
*modifyRequestOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyRequest = func() *modifyRequest {
mr := new(modifyRequest)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = NewModifyRequest
return mr
}()
func NewModifyRequest(optsRaw OptionsRaw, _ *ReverseProxy) (*Middleware, E.NestedError) {
mr := new(modifyRequest)
mr.m = &Middleware{
impl: mr,
rewrite: mr.modifyRequest,
}
mr.modifyRequestOpts = new(modifyRequestOpts)
err := Deserialize(optsRaw, mr.modifyRequestOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mr.m, nil
}
func (mr *modifyRequest) modifyRequest(req *Request) {
for k, v := range mr.SetHeaders {
req.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Header.Del(k)
}
}

View File

@@ -0,0 +1,34 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,60 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
type (
modifyResponse struct {
*modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyResponse = func() (mr *modifyResponse) {
mr = new(modifyResponse)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = NewModifyResponse
return
}()
func NewModifyResponse(optsRaw OptionsRaw, _ *ReverseProxy) (*Middleware, E.NestedError) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
modifyResponse: mr.modifyResponse,
}
mr.modifyResponseOpts = new(modifyResponseOpts)
err := Deserialize(optsRaw, mr.modifyResponseOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mr.m, nil
}
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}

View File

@@ -0,0 +1,35 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,157 @@
package middleware
import (
"net"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type realIP struct {
*realIPOpts
m *Middleware
}
type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string
// From is a list of Address / CIDRs to trust
From []*net.IPNet
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
the last address sent in the request header field defined by the Header field.
If recursive search is enabled,
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
var RealIP = &realIP{
m: &Middleware{
labelParserMap: D.ValueParserMap{
"from": CIDRListParser,
"recursive": D.BoolParser,
},
withOptions: NewRealIP,
},
}
var realIPOptsDefault = func() *realIPOpts {
return &realIPOpts{
Header: "X-Real-IP",
From: []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
{IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)},
},
}
}
var realIPLogger = logrus.WithField("middleware", "RealIP")
func NewRealIP(opts OptionsRaw, _ *ReverseProxy) (*Middleware, E.NestedError) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,
rewrite: riWithOpts.setRealIP,
}
riWithOpts.realIPOpts = realIPOptsDefault()
err := Deserialize(opts, riWithOpts.realIPOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return riWithOpts.m, nil
}
func CIDRListParser(s string) (any, E.NestedError) {
sl, err := D.YamlStringListParser(s)
if err != nil {
return nil, err
}
b := E.NewBuilder("invalid CIDR(s)")
CIDRs := sl.([]string)
res := make([]*net.IPNet, 0, len(CIDRs))
for _, cidr := range CIDRs {
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
b.Add(E.Invalid("CIDR", cidr))
continue
}
res = append(res, ipnet)
}
return res, b.Build()
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From {
if CIDR.Contains(ip) {
return true
}
}
// not in any CIDR
return false
}
func (ri *realIP) setRealIP(req *Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
realIPLogger.Debugf("failed to split host port from %s: %s", req.RemoteAddr, err)
}
clientIP := net.ParseIP(clientIPStr)
var isTrusted = false
for _, CIDR := range ri.From {
if CIDR.Contains(clientIP) {
isTrusted = true
break
}
}
if !isTrusted {
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
return
}
var realIPs = req.Header.Values(ri.Header)
var lastNonTrustedIP string
if len(realIPs) == 0 {
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
return
}
if !ri.Recursive {
lastNonTrustedIP = realIPs[len(realIPs)-1]
} else {
for _, r := range realIPs {
if !ri.isInCIDRList(net.ParseIP(r)) {
lastNonTrustedIP = r
}
}
if lastNonTrustedIP == "" {
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
return
}
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set("X-Real-IP", lastNonTrustedIP)
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
realIPLogger.Debugf("real ip %s", lastNonTrustedIP)
}

View File

@@ -0,0 +1,19 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
)
var RedirectHTTP = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
},
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/common"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
const testHost = "example.com"
func init() {
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestHeaderRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
}, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
proxyURL string
body []byte
scheme string
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
if args == nil {
args = new(testArgs)
}
if args.body != nil {
body = bytes.NewReader(args.body)
}
if args.scheme == "" || args.scheme == "http" {
requestTarget = "http://" + testHost
} else if args.scheme == "https" {
requestTarget = "https://" + testHost
} else {
panic("typo?")
}
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil {
panic("bug occurred")
}
if args.proxyURL != "" {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
if setOptErr != nil {
return nil, setOptErr
}
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
Data: data,
}, nil
}

View File

@@ -0,0 +1,44 @@
package middleware
import (
"net"
"strings"
)
var AddXForwarded = &Middleware{
rewrite: func(req *Request) {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
req.Header.Set("X-Forwarded-For", clientIP)
} else {
req.Header.Del("X-Forwarded-For")
}
req.Header.Set("X-Forwarded-Host", req.Host)
if req.TLS == nil {
req.Header.Set("X-Forwarded-Proto", "http")
} else {
req.Header.Set("X-Forwarded-Proto", "https")
}
},
}
var SetXForwarded = &Middleware{
rewrite: func(req *Request) {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
prior := req.Header["X-Forwarded-For"]
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
} else {
req.Header.Del("X-Forwarded-For")
}
req.Header.Set("X-Forwarded-Host", req.Host)
if req.TLS == nil {
req.Header.Set("X-Forwarded-Proto", "http")
} else {
req.Header.Set("X-Forwarded-Proto", "https")
}
},
}

View File

@@ -0,0 +1,96 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs
package http
import (
"bufio"
"fmt"
"net"
"net/http"
)
type ModifyResponseFunc func(*http.Response) error
type ModifyResponseWriter struct {
w http.ResponseWriter
r *http.Request
headerSent bool
code int
modifier ModifyResponseFunc
modified bool
modifierErr error
}
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
return &ModifyResponseWriter{
w: w,
r: r,
modifier: f,
code: http.StatusOK,
}
}
func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent {
return
}
if code >= http.StatusContinue && code < http.StatusOK {
w.w.WriteHeader(code)
}
defer func() {
w.headerSent = true
w.code = code
}()
if w.modifier == nil || w.modified {
w.w.WriteHeader(code)
return
}
resp := http.Response{
Header: w.w.Header(),
Request: w.r,
}
if err := w.modifier(&resp); err != nil {
w.modifierErr = err
logger.Errorf("error modifying response: %s", err)
w.w.WriteHeader(http.StatusInternalServerError)
return
}
w.modified = true
w.w.WriteHeader(code)
}
func (w *ModifyResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
w.WriteHeader(w.code)
if w.modifierErr != nil {
return 0, w.modifierErr
}
return w.w.Write(b)
}
// Hijack hijacks the connection.
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.w.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
}
// Flush sends any buffered data to the client.
func (w *ModifyResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,522 @@
// Copyright 2011 The Go Authors.
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
// https://cs.opensource.google/go/go/+/master:LICENSE
package http
// This is a small mod on net/http/httputil/reverseproxy.go
// that boosts performance in some cases
// and compatible to other modules of this project
// Copyright (c) 2024 yusing
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
In *http.Request
// Out is the request which will be sent by the proxy.
// The Rewrite function may modify or replace this request.
// Hop-by-hop headers are removed from this request
// before Rewrite is called.
Out *http.Request
}
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
// X-Forwarded-Proto headers of the outbound request.
//
// - The X-Forwarded-For header is set to the client IP address.
// - The X-Forwarded-Host header is set to the host name requested
// by the client.
// - The X-Forwarded-Proto header is set to "http" or "https", depending
// on whether the inbound request was made on a TLS-enabled connection.
//
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
// [ReverseProxy] when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
// r.SetXForwarded()
// }
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Rewrite must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Rewrite must not access the provided ProxyRequest
// or its contents after returning.
//
// The Forwarded, X-Forwarded, X-Forwarded-Host,
// and X-Forwarded-Proto headers are removed from the
// outbound request before Rewrite is called. See also
// the ProxyRequest.SetXForwarded method.
//
// Unparsable query parameters are removed from the
// outbound request before Rewrite is called.
// The Rewrite function may copy the inbound URL's
// RawQuery to the outbound URL to preserve the original
// parameter string. Note that this can lead to security
// issues if the proxy's interpretation of query parameters
// does not match that of the downstream server.
//
// At most one of Rewrite or Director may be set.
Rewrite func(*ProxyRequest)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called before ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
ServeHTTP http.HandlerFunc
}
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by [io.CopyBuffer].
type BufferPool interface {
Get() []byte
Put([]byte)
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
// NewReverseProxy returns a new [ReverseProxy] that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
//
// NewReverseProxy does not rewrite the Host header.
//
// To customize the ReverseProxy behavior beyond what
// NewReverseProxy provides, use ReverseProxy directly
// with a Rewrite function. The ProxyRequest SetURL method
// may be used to route the outbound request. (Note that SetURL,
// unlike NewReverseProxy, rewrites the Host header
// of the outbound request by default.)
//
// proxy := &ReverseProxy{
// Rewrite: func(r *ProxyRequest) {
// r.SetURL(target)
// r.Out.Host = r.In.Host // if desired
// },
// }
//
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
rp := &ReverseProxy{
Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target)
}, Transport: transport,
}
rp.ServeHTTP = rp.serveHTTP
return rp
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
logger.Errorf("http proxy to %s failed: %s", r.URL.String(), err)
if writeHeader {
rw.WriteHeader(http.StatusBadGateway)
}
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(res); err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
}
return true
}
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
// a Context that carries a cancellation signal, don't
// bother spinning up a goroutine to watch the CloseNotify
// channel (if any).
//
// If the request Context has a nil Done channel (which
// means it is either context.Background, or a custom
// Context implementation with no cancellation signal),
// then consult the CloseNotifier if available.
} else if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return
}
RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
outreq.Header.Del("Forwarded")
// outreq.Header.Del("X-Forwarded-For")
// outreq.Header.Del("X-Forwarded-Host")
// outreq.Header.Del("X-Forwarded-Proto")
pr := &ProxyRequest{
In: req,
Out: outreq,
}
p.Rewrite(pr)
outreq = pr.Out
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
h := rw.Header()
// copyHeader(h, http.Header(header))
for k, vv := range header {
for _, v := range vv {
h.Add(k, v)
}
}
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
if err != nil {
p.errorHandler(rw, outreq, err, false)
errMsg := err.Error()
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader([]byte(errMsg))),
Request: outreq,
ContentLength: int64(len(errMsg)),
TLS: outreq.TLS,
}
}
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
return
}
if !p.modifyResponse(rw, res, outreq) {
return
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
_, err = io.Copy(rw, res.Body)
if err != nil {
p.errorHandler(rw, req, err, true)
res.Body.Close()
return
}
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
http.NewResponseController(rw).Flush()
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
func UpgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// RemoveHopByHopHeaders removes hop-by-hop headers.
func RemoveHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := UpgradeType(req.Header)
resUpType := UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
}
if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.errorHandler(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"), true)
return
}
rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true)
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
if hijackErr != nil {
p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true)
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
errc := make(chan error, 1)
go func() {
_, err := io.Copy(conn, backConn)
errc <- err
}()
go func() {
_, err := io.Copy(backConn, conn)
errc <- err
}()
<-errc
}
func IsPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
var logger = logrus.WithField("module", "http")

View File

@@ -0,0 +1,7 @@
package http
import "net/http"
func IsSuccess(status int) bool {
return status >= http.StatusOK && status < http.StatusMultipleChoices
}