migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher

This commit is contained in:
yusing
2024-10-29 11:34:58 +08:00
parent cfa74d69ae
commit e5bbb18414
137 changed files with 2640 additions and 2348 deletions

View File

@@ -15,9 +15,9 @@ type cidrWhitelist struct {
}
type cidrWhitelistOpts struct {
Allow []*types.CIDR
StatusCode int
Message string
Allow []*types.CIDR `json:"allow"`
StatusCode int `json:"statusCode"`
Message string `json:"message"`
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
@@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
return nil, err
}
if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.Missing("allow range")
return nil, E.New("no allowed CIDRs")
}
return wl.m, nil
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
if err != nil {
panic(err)
}
errs := E.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
@@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("deny", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
}
@@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("accept", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})

View File

@@ -10,10 +10,10 @@ import (
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
const (
@@ -26,7 +26,7 @@ const (
var (
cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &realIP{
@@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
}
cfCIDRsLastUpdate = time.Now()
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
return
}

View File

@@ -8,24 +8,31 @@ import (
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
var CustomErrorPage = &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
modifyResponse: func(resp *Response) error {
var CustomErrorPage *Middleware
func init() {
CustomErrorPage = customErrorPage()
}
func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
@@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
return m
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
@@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
filename := path[len(gphttp.StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
errPageLogger.Errorf("unable to load resource %s", filename)
logger.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
@@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename)
logger.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View File

@@ -0,0 +1,94 @@
package errorpage
import (
"fmt"
"os"
"path"
"sync"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var (
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
var setup = sync.OnceFunc(func() {
task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
loadContent()
go watchDir(task)
})
func GetStaticFile(filename string) ([]byte, bool) {
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
return
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
logger.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
if fileContentMap.Has(file) {
continue
}
content, err := os.ReadFile(file)
if err != nil {
logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir(task task.Task) {
eventCh, errCh := dirWatcher.Events(task.Context())
for {
select {
case <-task.Context().Done():
return
case event, ok := <-eventCh:
if !ok {
return
}
filename := event.ActorName
switch event.Action {
case events.ActionFileWritten:
fileContentMap.Delete(filename)
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
logger.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
logger.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
E.LogError("error watching error page directory", err, &logger)
}
}
}

View File

@@ -0,0 +1,5 @@
package errorpage
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "errorpage").Logger()

View File

@@ -24,11 +24,12 @@ type (
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
transport http.RoundTripper
Address string `json:"address"`
TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
}
)
@@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts)
if err != nil {
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
return nil, err
}
_, err = E.Check(url.Parse(fa.Address))
if err != nil {
return nil, E.Invalid("address", fa.Address)
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{

View File

@@ -0,0 +1,5 @@
package middleware
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "middleware").Logger()

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
@@ -32,6 +33,8 @@ type (
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
@@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if len(optsRaw) != 0 && m.withOptions != nil {
return m.withOptions(optsRaw)
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
@@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
middlewares = make([]*Middleware, 0, len(middlewaresMap))
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err := m.WithOptionsClone(opts)
m, err = m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
middlewares = append(middlewares, m)
}
return
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
return middlewares, errs.Error()
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {

View File

@@ -4,64 +4,60 @@ import (
"fmt"
"net/http"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
eb.Add(err)
return nil
}
return BuildMiddlewaresFromYAML(fileContent)
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("yaml unmarshal", err))
return
eb.Add(err)
return nil
}
middlewares = make(map[string]*Middleware)
middlewares := make(map[string]*Middleware)
for name, defs := range rawMap {
chainErr := E.NewBuilder("%s", name)
chainErr := E.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
chainErr.Addf("item %d: missing field 'use'", i)
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
base, err := Get(baseName)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {
b.Add(chainErr.Build())
eb.Add(chainErr.Error().Subject(source))
} else {
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
return middlewares
}
// TODO: check conflict or duplicates.
@@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware")
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return b.Build().Error()
return errs.Error()
}
}

View File

@@ -13,10 +13,10 @@ import (
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
ExpectNoError(t, err.Error())
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
errs := E.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
E.Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data))
// TODO: test
}

View File

@@ -6,26 +6,37 @@ import (
"path"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var middlewares map[string]*Middleware
var allMiddlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[U.ToLowerNoSnake(name)]
return
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return middlewares
return allMiddlewares
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
@@ -39,10 +50,10 @@ func init() {
// !experimental
"forwardauth": ForwardAuth.m,
"oauth2": OAuth2.m,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name))
}
for m, names := range names {
@@ -55,27 +66,30 @@ func init() {
}
func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares")
errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err)
logger.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromComposeFile(defFile)
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
middlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
allMiddlewares[U.ToLowerNoSnake(name)] = m
logger.Info().
Str("name", name).
Str("src", path.Base(defFile)).
Msg("middleware loaded")
}
b.Add(err.Subject(path.Base(defFile)))
}
if b.HasError() {
logger.Error(b.Build())
if errs.HasError() {
E.LogError(errs.About(), errs.Error(), &logger)
}
}
var logger = logrus.WithField("module", "middlewares")

View File

@@ -12,9 +12,9 @@ type (
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
SetHeaders map[string]string `json:"setHeaders"`
AddHeaders map[string]string `json:"addHeaders"`
HideHeaders []string `json:"hideHeaders"`
}
)

View File

@@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
@@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")

View File

@@ -13,11 +13,7 @@ type (
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
modifyResponseOpts = modifyRequestOpts
)
var ModifyResponse = &modifyResponse{

View File

@@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
@@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))

View File

@@ -1,129 +1,117 @@
package middleware
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
// import (
// "encoding/json"
// "fmt"
// "net/http"
// "net/url"
E "github.com/yusing/go-proxy/internal/error"
)
// E "github.com/yusing/go-proxy/internal/error"
// )
type oAuth2 struct {
*oAuth2Opts
m *Middleware
}
// type oAuth2 struct {
// oAuth2Opts
// m *Middleware
// }
type oAuth2Opts struct {
ClientID string
ClientSecret string
AuthURL string // Authorization Endpoint
TokenURL string // Token Endpoint
}
// type oAuth2Opts struct {
// ClientID string `validate:"required"`
// ClientSecret string `validate:"required"`
// AuthURL string `validate:"required"` // Authorization Endpoint
// TokenURL string `validate:"required"` // Token Endpoint
// }
var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2},
}
// var OAuth2 = &oAuth2{
// m: &Middleware{withOptions: NewAuthentikOAuth2},
// }
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
oauth := new(oAuth2)
oauth.m = &Middleware{
impl: oauth,
before: oauth.handleOAuth2,
}
oauth.oAuth2Opts = &oAuth2Opts{}
err := Deserialize(opts, oauth.oAuth2Opts)
if err != nil {
return nil, err
}
b := E.NewBuilder("missing required fields")
optV := reflect.ValueOf(oauth.oAuth2Opts)
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
if optV.FieldByName(field.Name).Len() == 0 {
b.Add(E.Missing(field.Name))
}
}
if b.HasError() {
return nil, b.Build().Subject("oAuth2")
}
return oauth.m, nil
}
// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
// oauth := new(oAuth2)
// oauth.m = &Middleware{
// impl: oauth,
// before: oauth.handleOAuth2,
// }
// err := Deserialize(opts, &oauth.oAuth2Opts)
// if err != nil {
// return nil, err
// }
// return oauth.m, nil
// }
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// Check if the user is authenticated (you may use session, cookie, etc.)
if !userIsAuthenticated(r) {
// TODO: Redirect to OAuth2 auth URL
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
return
}
// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// // Check if the user is authenticated (you may use session, cookie, etc.)
// if !userIsAuthenticated(r) {
// // TODO: Redirect to OAuth2 auth URL
// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
// return
// }
// If you have a token in the query string, process it
if code := r.URL.Query().Get("code"); code != "" {
// Exchange the authorization code for a token here
// Use the TokenURL and authenticate the user
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
if err != nil {
// handle error
http.Error(rw, "failed to get token", http.StatusUnauthorized)
return
}
// // If you have a token in the query string, process it
// if code := r.URL.Query().Get("code"); code != "" {
// // Exchange the authorization code for a token here
// // Use the TokenURL and authenticate the user
// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
// if err != nil {
// // handle error
// http.Error(rw, "failed to get token", http.StatusUnauthorized)
// return
// }
// Save token and user info based on your requirements
saveToken(rw, token)
// // Save token and user info based on your requirements
// saveToken(rw, token)
// Redirect to the originally requested URL
http.Redirect(rw, r, "/", http.StatusFound)
return
}
// // Redirect to the originally requested URL
// http.Redirect(rw, r, "/", http.StatusFound)
// return
// }
// If user is authenticated, go to the next handler
next(rw, r)
}
// // If user is authenticated, go to the next handler
// next(rw, r)
// }
func userIsAuthenticated(r *http.Request) bool {
// Example: Check for a session or cookie
session, err := r.Cookie("session_token")
if err != nil || session.Value == "" {
return false
}
// Validate the session_token if necessary
return true
}
// func userIsAuthenticated(r *http.Request) bool {
// // Example: Check for a session or cookie
// session, err := r.Cookie("session_token")
// if err != nil || session.Value == "" {
// return false
// }
// // Validate the session_token if necessary
// return true
// }
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// Prepare the request body
data := url.Values{
"client_id": {opts.ClientID},
"client_secret": {opts.ClientSecret},
"code": {code},
"grant_type": {"authorization_code"},
"redirect_uri": {requestURI},
}
resp, err := http.PostForm(opts.TokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to request token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
}
// Decode the response
var tokenResp struct {
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
return tokenResp.AccessToken, nil
}
// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// // Prepare the request body
// data := url.Values{
// "client_id": {opts.ClientID},
// "client_secret": {opts.ClientSecret},
// "code": {code},
// "grant_type": {"authorization_code"},
// "redirect_uri": {requestURI},
// }
// resp, err := http.PostForm(opts.TokenURL, data)
// if err != nil {
// return "", fmt.Errorf("failed to request token: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
// }
// // Decode the response
// var tokenResp struct {
// AccessToken string `json:"access_token"`
// }
// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
// return "", fmt.Errorf("failed to decode token response: %w", err)
// }
// return tokenResp.AccessToken, nil
// }
func saveToken(rw ResponseWriter, token string) {
// Example: Save token in cookie
http.SetCookie(rw, &http.Cookie{
Name: "auth_token",
Value: token,
// set other properties as necessary, such as Secure and HttpOnly
})
}
// func saveToken(rw ResponseWriter, token string) {
// // Example: Save token in cookie
// http.SetCookie(rw, &http.Cookie{
// Name: "auth_token",
// Value: token,
// // set other properties as necessary, such as Secure and HttpOnly
// })
// }

View File

@@ -16,9 +16,9 @@ type realIP struct {
type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string
Header string `json:"header"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR
From []*types.CIDR `json:"from"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@@ -27,7 +27,7 @@ type realIPOpts struct {
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
Recursive bool `json:"recursive"`
}
var RealIP = &realIP{

View File

@@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
}
ri, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
@@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) {
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)

View File

@@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
@@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -6,7 +6,7 @@ import (
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Trace struct {
@@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
return nil
}
return addTrace(&Trace{
Time: U.FormatTime(time.Now()),
Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})