restructured the project to comply community guideline, for others check release note

This commit is contained in:
yusing
2024-09-28 09:51:34 +08:00
parent 4120fd8d1c
commit 90487bfde6
134 changed files with 1110 additions and 625 deletions

View File

@@ -0,0 +1,8 @@
package route
import (
"time"
)
const udpBufferSize = 8192
const streamStopListenTimeout = 1 * time.Second

229
internal/route/http.go Executable file
View File

@@ -0,0 +1,229 @@
package route
import (
"encoding/json"
"fmt"
"slices"
"sync"
"net/http"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/http"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/route/middleware"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
HTTPRoute struct {
Alias PT.Alias `json:"alias"`
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
entry *P.ReverseProxyEntry
mux *http.ServeMux
handler *ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
}
URL url.URL
SubdomainKey = PT.Alias
)
var (
findMuxFunc = findMuxAnyDomain
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
)
func SetFindMuxDomains(domains []string) {
if len(domains) == 0 {
findMuxFunc = findMuxAnyDomain
} else {
findMuxFunc = findMuxByDomains(domains)
}
}
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans *http.Transport
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
if entry.NoTLSVerify {
trans = common.DefaultTransportNoTLS.Clone()
} else {
trans = common.DefaultTransport.Clone()
}
rp := NewReverseProxy(entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
if err != nil {
return nil, err
}
}
if entry.UseIdleWatcher() {
// allow time for response header up to `WakeTimeout`
if entry.WakeTimeout > trans.ResponseHeaderTimeout {
trans.ResponseHeaderTimeout = entry.WakeTimeout
}
regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry)
if err.HasError() {
return err
}
// patch round-tripper
rp.Transport = watcher.PatchRoundTripper(trans)
return nil
}
unregIdleWatcher = func() {
idlewatcher.Unregister(entry.ContainerName)
rp.Transport = trans
}
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.Duplicated("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
handler: rp,
regIdleWatcher: regIdleWatcher,
unregIdleWatcher: unregIdleWatcher,
}
return r, nil
}
func (r *HTTPRoute) String() string {
return string(r.Alias)
}
func (r *HTTPRoute) Start() E.NestedError {
if r.mux != nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
r.unregIdleWatcher = nil
return err
}
}
r.mux = http.NewServeMux()
for _, p := range r.PathPatterns {
r.mux.HandleFunc(string(p), r.handler.ServeHTTP)
}
httpRoutes.Store(r.Alias, r)
return nil
}
func (r *HTTPRoute) Stop() E.NestedError {
if r.mux == nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.unregIdleWatcher != nil {
r.unregIdleWatcher()
r.unregIdleWatcher = nil
}
r.mux = nil
httpRoutes.Delete(r.Alias)
return nil
}
func (u *URL) String() string {
return (*url.URL)(u).String()
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMuxFunc(r.Host)
if err != nil {
logrus.Error(E.Failure("request").
Subjectf("%s %s", r.Method, r.URL.String()).
With(err))
acceptTypes := r.Header.Values("Accept")
switch {
case slices.Contains(acceptTypes, "text/html"):
case slices.Contains(acceptTypes, "application/json"):
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
default:
http.Error(w, err.Error(), http.StatusNotFound)
}
return
}
mux.ServeHTTP(w, r)
}
func findMuxAnyDomain(host string) (*http.ServeMux, error) {
hostSplit := strings.Split(host, ".")
n := len(hostSplit)
if n <= 2 {
return nil, fmt.Errorf("missing subdomain in url")
}
sd := strings.Join(hostSplit[:n-2], ".")
if r, ok := httpRoutes.Load(PT.Alias(sd)); ok {
return r.mux, nil
}
return nil, fmt.Errorf("no such route: %s", sd)
}
func findMuxByDomains(domains []string) func(host string) (*http.ServeMux, error) {
return func(host string) (*http.ServeMux, error) {
var subdomain string
for _, domain := range domains {
if !strings.HasPrefix(domain, ".") {
domain = "." + domain
}
subdomain = strings.TrimSuffix(host, domain)
if len(subdomain) < len(host) {
break
}
}
if len(subdomain) == len(host) { // not matched
return nil, fmt.Errorf("%s does not match any base domain", host)
}
if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok {
return r.mux, nil
}
return nil, fmt.Errorf("no such route: %s", subdomain)
}
}

View File

@@ -0,0 +1,70 @@
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
const staticFilePathPrefix = "/$gperrorpage/"
var CustomErrorPage = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, staticFilePathPrefix) {
filename := path[len(staticFilePathPrefix):]
file, ok := error_page.GetStaticFile(filename)
if !ok {
http.NotFound(w, r)
errPageLogger.Errorf("unable to load resource %s", filename)
} else {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
w.Write(file)
}
return
}
next.ServeHTTP(w, r)
},
modifyResponse: func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gpHTTP.GetContentType(resp.Header)
if !gpHTTP.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View File

@@ -0,0 +1,249 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
forwardAuth struct {
*forwardAuthOpts
m *Middleware
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
const (
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
)
var ForwardAuth = newForwardAuth()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func newForwardAuth() (fa *forwardAuth) {
fa = new(forwardAuth)
fa.m = new(Middleware)
fa.m.labelParserMap = D.ValueParserMap{
"trust_forward_header": D.BoolParser,
"auth_response_headers": D.YamlStringListParser,
"add_auth_cookies_to_response": D.YamlStringListParser,
}
fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
tr, ok := rp.Transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = common.DefaultTransport.Clone()
}
faWithOpts := new(forwardAuth)
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
faWithOpts.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
faWithOpts.m = &Middleware{
impl: faWithOpts,
before: faWithOpts.forward,
}
err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
_, err = E.Check(url.Parse(faWithOpts.Address))
if err != nil {
return nil, E.Invalid("address", faWithOpts.Address)
}
return faWithOpts.m, nil
}
return
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
gpHTTP.RemoveHop(req.Header)
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
gpHTTP.CopyHeader(faReq.Header, req.Header)
gpHTTP.RemoveHop(faReq.Header)
gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
}
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
}
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
}
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
type (
Error = E.NestedError
ReverseProxy = gpHTTP.ReverseProxy
ProxyRequest = gpHTTP.ProxyRequest
Request = http.Request
Response = http.Response
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
RewriteFunc func(req *ProxyRequest)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
OptionsRaw = map[string]any
Options any
Middleware struct {
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
}
)
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
for name, opts := range middlewares {
m, ok := Get(name)
if !ok {
invalidM.Addf("%s", name)
continue
}
m, err := m.WithOptionsClone(opts, rp)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
}
}
if invalidM.HasError() {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
}
}
if len(rewrites) > 0 {
if rp.Rewrite != nil {
rewrites = append([]RewriteFunc{rp.Rewrite}, rewrites...)
}
rp.Rewrite = func(req *ProxyRequest) {
for _, rewrite := range rewrites {
rewrite(req)
}
}
}
if len(modResps) > 0 {
if rp.ModifyResponse != nil {
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
}
rp.ModifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware ModifyResponse")
for _, mr := range modResps {
b.AddE(mr(res))
}
return b.Build().Error()
}
}
return
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"fmt"
"strings"
D "github.com/yusing/go-proxy/internal/docker"
)
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
return
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"add_x_forwarded": AddXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
D.RegisterNamespace(name, m.labelParserMap)
}
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
}
}
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
modifyRequest struct {
*modifyRequestOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyRequest = newModifyRequest()
func newModifyRequest() (mr *modifyRequest) {
mr = new(modifyRequest)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyRequest)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
rewrite: mrWithOpts.modifyRequest,
}
mrWithOpts.modifyRequestOpts = new(modifyRequestOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyRequest) modifyRequest(req *ProxyRequest) {
for k, v := range mr.SetHeaders {
req.Out.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Out.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Out.Header.Del(k)
}
}

View File

@@ -0,0 +1,34 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,61 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
modifyResponse struct {
*modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyResponse = newModifyResponse()
func newModifyResponse() (mr *modifyResponse) {
mr = new(modifyResponse)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyResponse)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
modifyResponse: mrWithOpts.modifyResponse,
}
mrWithOpts.modifyResponseOpts = new(modifyResponseOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}

View File

@@ -0,0 +1,35 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,19 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
)
var RedirectHTTP = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
},
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/common"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
const testHost = "example.com"
func init() {
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestHeaderRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
}, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
proxyURL string
body []byte
scheme string
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
if args == nil {
args = new(testArgs)
}
if args.body != nil {
body = bytes.NewReader(args.body)
}
if args.scheme == "" || args.scheme == "http" {
requestTarget = "http://" + testHost
} else if args.scheme == "https" {
requestTarget = "https://" + testHost
} else {
panic("typo?")
}
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil {
panic("bug occurred")
}
if args.proxyURL != "" {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
if setOptErr != nil {
return nil, setOptErr
}
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
Data: data,
}, nil
}

View File

@@ -0,0 +1,9 @@
package middleware
var AddXForwarded = &Middleware{
rewrite: (*ProxyRequest).AddXForwarded,
}
var SetXForwarded = &Middleware{
rewrite: (*ProxyRequest).SetXForwarded,
}

95
internal/route/route.go Executable file
View File

@@ -0,0 +1,95 @@
package route
import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
P "github.com/yusing/go-proxy/internal/proxy"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
Route interface {
RouteImpl
Entry() *M.RawEntry
Type() RouteType
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteImpl interface {
Start() E.NestedError
Stop() E.NestedError
String() string
}
RouteType string
route struct {
RouteImpl
type_ RouteType
entry *M.RawEntry
}
)
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias
var NewRoutes = F.NewMapOf[string, Route]
func NewRoute(en *M.RawEntry) (Route, E.NestedError) {
rt, err := P.ValidateEntry(en)
if err != nil {
return nil, err
}
var t RouteType
switch e := rt.(type) {
case *P.StreamEntry:
rt, err = NewStreamRoute(e)
t = RouteTypeStream
case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
default:
panic("bug: should not reach here")
}
if err != nil {
return nil, err
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, nil
}
func (rt *route) Entry() *M.RawEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host))
return url
}
func FromEntries(entries M.RawEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *M.RawEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)
}
})
return routes, b.Build()
}

141
internal/route/stream.go Executable file
View File

@@ -0,0 +1,141 @@
package route
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
)
type StreamRoute struct {
*P.StreamEntry
StreamImpl `json:"-"`
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
connCh chan any
started atomic.Bool
l logrus.FieldLogger
}
type StreamImpl interface {
Setup() error
Accept() (any, error)
Handle(any) error
CloseListeners()
String() string
}
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
// TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
base := &StreamRoute{
StreamEntry: entry,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base)
} else {
base.StreamImpl = NewUDPRoute(base)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, nil
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
}
func (r *StreamRoute) Start() E.NestedError {
if r.started.Load() {
return nil
}
r.ctx, r.cancel = context.WithCancel(context.Background())
r.wg.Wait()
if err := r.Setup(); err != nil {
return E.FailWith("setup", err)
}
r.l.Infof("listening on port %d", r.Port.ListeningPort)
r.started.Store(true)
r.wg.Add(2)
go r.grAcceptConnections()
go r.grHandleConnections()
return nil
}
func (r *StreamRoute) Stop() E.NestedError {
if !r.started.Load() {
return nil
}
l := r.l
r.cancel()
r.CloseListeners()
done := make(chan struct{}, 1)
go func() {
r.wg.Wait()
close(done)
}()
timeout := time.After(streamStopListenTimeout)
for {
select {
case <-done:
l.Debug("stopped listening")
return nil
case <-timeout:
return E.FailedWhy("stop", "timed out")
}
}
}
func (r *StreamRoute) grAcceptConnections() {
defer r.wg.Done()
for {
select {
case <-r.ctx.Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-r.ctx.Done():
return
default:
r.l.Error(err)
continue
}
}
r.connCh <- conn
}
}
}
func (r *StreamRoute) grHandleConnections() {
defer r.wg.Done()
for {
select {
case <-r.ctx.Done():
return
case conn := <-r.connCh:
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
}
}()
}
}
}

80
internal/route/tcp.go Executable file
View File

@@ -0,0 +1,80 @@
package route
import (
"context"
"fmt"
"net"
"sync"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
)
const tcpDialTimeout = 5 * time.Second
type (
Pipes []U.BidirectionalPipe
TCPRoute struct {
*StreamRoute
listener net.Listener
pipe Pipes
mu sync.Mutex
}
)
func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{
StreamRoute: base,
pipe: make(Pipes, 0),
}
}
func (route *TCPRoute) Setup() error {
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
//! this read the allocated port from orginal ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in
return nil
}
func (route *TCPRoute) Accept() (any, error) {
return route.listener.Accept()
}
func (route *TCPRoute) Handle(c any) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
ctx, cancel := context.WithTimeout(route.ctx, tcpDialTimeout)
defer cancel()
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
if err != nil {
return err
}
route.mu.Lock()
pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
func (route *TCPRoute) CloseListeners() {
if route.listener == nil {
return
}
route.listener.Close()
route.listener = nil
}

129
internal/route/udp.go Executable file
View File

@@ -0,0 +1,129 @@
package route
import (
"fmt"
"io"
"net"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPRoute struct {
*StreamRoute
connMap UDPConnMap
listeningConn *net.UDPConn
targetAddr *net.UDPAddr
}
UDPConn struct {
src *net.UDPConn
dst *net.UDPConn
U.BidirectionalPipe
}
UDPConnMap = F.Map[string, *UDPConn]
)
var NewUDPConnMap = F.NewMapOf[string, *UDPConn]
func NewUDPRoute(base *StreamRoute) StreamImpl {
return &UDPRoute{
StreamRoute: base,
connMap: NewUDPConnMap(),
}
}
func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
if err != nil {
return err
}
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
source.Close()
return err
}
//! this read the allocated listeningPort from orginal ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
route.targetAddr = raddr
return nil
}
func (route *UDPRoute) Accept() (any, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
key := srcAddr.String()
conn, ok := route.connMap.Load(key)
if !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
if err != nil {
srcConn.Close()
return nil, err
}
conn = &UDPConn{
srcConn,
dstConn,
U.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap.Store(key, conn)
}
_, err = conn.dst.Write(buffer[:nRead])
return conn, err
}
func (route *UDPRoute) Handle(c any) error {
return c.(*UDPConn).Start()
}
func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
route.listeningConn = nil
}
route.connMap.RangeAll(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err)
}
if err := conn.dst.Close(); err != nil {
route.l.Error("error closing dst conn: %s", err)
}
})
route.connMap.Clear()
}
type sourceRWCloser struct {
server *net.UDPConn
*net.UDPConn
}
func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}