mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-10 18:56:55 +02:00
added golangci-linting, refactor, simplified error msgs and fixed some error handling
This commit is contained in:
@@ -10,8 +10,8 @@ import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
// TODO: stats of each server
|
||||
// TODO: support weighted mode
|
||||
// TODO: stats of each server.
|
||||
// TODO: support weighted mode.
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
@@ -2,14 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/error_page"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
|
||||
@@ -23,14 +23,15 @@ var CustomErrorPage = &Middleware{
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
} else {
|
||||
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
|
||||
@@ -48,25 +49,27 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||
}
|
||||
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
|
||||
filename := path[len(gphttp.StaticFilePathPrefix):]
|
||||
file, ok := error_page.GetStaticFile(filename)
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
errPageLogger.Errorf("unable to load resource %s", filename)
|
||||
return false
|
||||
} else {
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
default:
|
||||
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
w.Write(file)
|
||||
return true
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
default:
|
||||
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
for name, defs := range rawMap {
|
||||
chainErr := E.NewBuilder(name)
|
||||
chainErr := E.NewBuilder("%s", name)
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
@@ -64,7 +64,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
m := &Middleware{name: name, children: chain}
|
||||
|
||||
|
||||
@@ -92,18 +92,18 @@ func userIsAuthenticated(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) {
|
||||
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
|
||||
// Prepare the request body
|
||||
data := url.Values{
|
||||
"client_id": {opts.ClientID},
|
||||
"client_secret": {opts.ClientSecret},
|
||||
"code": {code},
|
||||
"grant_type": {"authorization_code"},
|
||||
"redirect_uri": {requestUri},
|
||||
"redirect_uri": {requestURI},
|
||||
}
|
||||
resp, err := http.PostForm(opts.TokenURL, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to request token: %v", err)
|
||||
return "", fmt.Errorf("failed to request token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
@@ -114,7 +114,7 @@ func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (str
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %v", err)
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ type testArgs struct {
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||
var body io.Reader
|
||||
var rr = new(requestRecorder)
|
||||
var rr requestRecorder
|
||||
var proxyURL *url.URL
|
||||
var requestTarget string
|
||||
var err error
|
||||
@@ -87,11 +87,14 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||
body = bytes.NewReader(args.body)
|
||||
}
|
||||
|
||||
if args.scheme == "" || args.scheme == "http" {
|
||||
switch args.scheme {
|
||||
case "":
|
||||
fallthrough
|
||||
case "http":
|
||||
requestTarget = "http://" + testHost
|
||||
} else if args.scheme == "https" {
|
||||
case "https":
|
||||
requestTarget = "https://" + testHost
|
||||
} else {
|
||||
default:
|
||||
panic("typo?")
|
||||
}
|
||||
|
||||
@@ -111,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr)
|
||||
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
|
||||
@@ -24,10 +24,9 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
|
||||
@@ -222,6 +221,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
|
||||
ctx := req.Context()
|
||||
/* trunk-ignore(golangci-lint/revive) */
|
||||
if ctx.Done() != nil {
|
||||
// CloseNotifier predates context.Context, and has been
|
||||
// entirely superseded by it. If the request contains
|
||||
@@ -460,7 +460,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
||||
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.errorHandler(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"), true)
|
||||
p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -494,21 +494,24 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
||||
res.Header = rw.Header()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
|
||||
return
|
||||
}
|
||||
|
||||
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
bdp.Start()
|
||||
}
|
||||
|
||||
func IsPrint(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < ' ' || s[i] > '~' {
|
||||
for _, r := range s {
|
||||
if r < ' ' || r > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user