Compare commits

...

7 Commits
0.6 ... 0.6.3

87 changed files with 1638 additions and 564 deletions

View File

@@ -11,7 +11,7 @@ env:
jobs:
build:
name: Build multi-platform Docker image
runs-on: self-hosted
runs-on: ubuntu-22.04
permissions:
contents: read
@@ -85,7 +85,7 @@ jobs:
if-no-files-found: error
retention-days: 1
merge:
runs-on: self-hosted
runs-on: ubuntu-22.04
needs:
- build
permissions:

136
.golangci.yml Normal file
View File

@@ -0,0 +1,136 @@
run:
timeout: 10m
linters-settings:
govet:
enable-all: true
disable:
- shadow
- fieldalignment
gocyclo:
min-complexity: 14
goconst:
min-len: 3
min-occurrences: 4
misspell:
locale: US
funlen:
lines: -1
statements: 120
forbidigo:
forbid:
- ^print(ln)?$
godox:
keywords:
- FIXME
tagalign:
align: false
sort: true
order:
- description
- json
- toml
- yaml
- yml
- label
- label-slice-as-struct
- file
- kv
- export
stylecheck:
dot-import-whitelist:
- github.com/yusing/go-proxy/internal/utils/testing # go tests only
- github.com/yusing/go-proxy/internal/api/v1/utils # api only
revive:
rules:
- name: struct-tag
- name: blank-imports
- name: context-as-argument
- name: context-keys-type
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
disabled: true
- name: if-return
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
disabled: true
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
- name: errorf
- name: empty-block
- name: superfluous-else
- name: unused-parameter
disabled: true
- name: unreachable-code
- name: redefines-builtin-id
gomoddirectives:
replace-allow-list:
- github.com/abbot/go-http-auth
- github.com/gorilla/mux
- github.com/mailgun/minheap
- github.com/mailgun/multibuf
- github.com/jaguilar/vt100
- github.com/cucumber/godog
- github.com/http-wasm/http-wasm-host-go
testifylint:
disable:
- suite-dont-use-pkg
- require-error
- go-require
staticcheck:
checks:
- all
- -SA1019
errcheck:
exclude-functions:
- fmt.Fprintln
linters:
enable-all: true
disable:
- execinquery # deprecated
- gomnd # deprecated
- sqlclosecheck # not relevant (SQL)
- rowserrcheck # not relevant (SQL)
- cyclop # duplicate of gocyclo
- depguard # Not relevant
- nakedret # Too strict
- lll # Not relevant
- gocyclo # FIXME must be fixed
- gocognit # Too strict
- nestif # Too many false-positive.
- prealloc # Too many false-positive.
- makezero # Not relevant
- dupl # Too strict
- gosec # Too strict
- gochecknoinits
- gochecknoglobals
- wsl # Too strict
- nlreturn # Not relevant
- mnd # Too strict
- testpackage # Too strict
- tparallel # Not relevant
- paralleltest # Not relevant
- exhaustive # Not relevant
- exhaustruct # Not relevant
- err113 # Too strict
- wrapcheck # Too strict
- noctx # Too strict
- bodyclose # too many false-positive
- forcetypeassert # Too strict
- tagliatelle # Too strict
- varnamelen # Not relevant
- nilnil # Not relevant
- ireturn # Not relevant
- contextcheck # too many false-positive
- containedctx # too many false-positive
- maintidx # kind of duplicate of gocyclo
- nonamedreturns # Too strict
- gosmopolitan # not relevant
- exportloopref # Not relevant since go1.22

9
.trunk/.gitignore vendored Normal file
View File

@@ -0,0 +1,9 @@
*out
*logs
*actions
*notifications
*tools
plugins
user_trunk.yaml
user.yaml
tmp

41
.trunk/trunk.yaml Normal file
View File

@@ -0,0 +1,41 @@
# This file controls the behavior of Trunk: https://docs.trunk.io/cli
# To learn more about the format of this file, see https://docs.trunk.io/reference/trunk-yaml
version: 0.1
cli:
version: 1.22.6
# Trunk provides extensibility via plugins. (https://docs.trunk.io/plugins)
plugins:
sources:
- id: trunk
ref: v1.6.3
uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes:
enabled:
- node@18.12.1
- python@3.10.8
- go@1.23.2
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
lint:
enabled:
- hadolint@2.12.0
- actionlint@1.7.3
- checkov@3.2.257
- git-diff-check
- gofmt@1.20.4
- golangci-lint@1.61.0
- markdownlint@0.42.0
- osv-scanner@1.9.0
- oxipng@9.1.2
- prettier@3.3.3
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.82.7
- yamllint@1.35.1
actions:
disabled:
- trunk-announce
- trunk-check-pre-push
- trunk-fmt-pre-commit
enabled:
- trunk-upgrade-available

View File

@@ -11,7 +11,6 @@ import (
"reflect"
"runtime"
"strings"
"sync"
"syscall"
"time"
@@ -137,13 +136,11 @@ func main() {
signal.Notify(sig, syscall.SIGHUP)
autocert := cfg.GetAutoCertProvider()
if autocert != nil {
ctx, cancel := context.WithCancel(context.Background())
onShutdown.Add(cancel)
if err := autocert.Setup(ctx); err != nil {
l.Fatal(err)
} else {
onShutdown.Add(cancel)
}
} else {
l.Info("autocert not configured")
@@ -179,19 +176,15 @@ func main() {
// grafully shutdown
logrus.Info("shutting down")
done := make(chan struct{}, 1)
currentIdx := 0
var wg sync.WaitGroup
wg.Add(onShutdown.Size())
onShutdown.ForEach(func(f func()) {
go func() {
go func() {
onShutdown.ForEach(func(f func()) {
l.Debugf("waiting for %s to complete...", funcName(f))
f()
currentIdx++
l.Debugf("%s done", funcName(f))
wg.Done()
}()
})
go func() {
wg.Wait()
})
close(done)
}()
@@ -201,9 +194,9 @@ func main() {
logrus.Info("shutdown complete")
case <-timeout:
logrus.Info("timeout waiting for shutdown")
onShutdown.ForEach(func(f func()) {
l.Warnf("%s() is still running", funcName(f))
})
for i := currentIdx; i < onShutdown.Size(); i++ {
l.Warnf("%s() is still running", funcName(onShutdown.Get(i)))
}
}
}

2
go.mod
View File

@@ -19,7 +19,7 @@ require (
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudflare/cloudflare-go v0.106.0 // indirect
github.com/cloudflare/cloudflare-go v0.107.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect

4
go.sum
View File

@@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cloudflare/cloudflare-go v0.106.0 h1:q41gC5Wc1nfi0D1ZhSHokWcd9mGMbqC7RE7qiP+qE00=
github.com/cloudflare/cloudflare-go v0.106.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc=
github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=

View File

@@ -5,7 +5,7 @@ import (
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
@@ -36,11 +36,11 @@ func NewHandler(cfg *config.Config) http.Handler {
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS))
mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc())
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux
}
// allow only requests to API server with host matching common.APIHTTPAddr
// allow only requests to API server with host matching common.APIHTTPAddr.
func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug {
return f
@@ -48,8 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Host != common.APIHTTPAddr {
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr)
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("invalid request"))
http.Error(w, "invalid request", http.StatusForbidden)
return
}
f(w, r)

View File

@@ -5,7 +5,7 @@ import (
"net/http"
"strings"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
R "github.com/yusing/go-proxy/internal/route"
)
@@ -13,7 +13,7 @@ import (
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target")
if target == "" {
U.HandleErr(w, r, U.ErrMissingKey("target"), http.StatusBadRequest)
HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest)
return
}
@@ -22,7 +22,7 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
switch {
case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound)
return
case route.Type() == R.RouteTypeReverseProxy:
ok = IsSiteHealthy(route.URL().String())

View File

@@ -1,4 +1,4 @@
package error_page
package errorpage
import (
"context"
@@ -7,7 +7,7 @@ import (
"path"
"sync"
api "github.com/yusing/go-proxy/internal/api/v1/utils"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@@ -17,6 +17,11 @@ import (
const errPagesBasePath = common.ErrorPagesBasePath
var (
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
loadContent()
@@ -27,7 +32,7 @@ func GetStaticFile(filename string) ([]byte, bool) {
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
@@ -39,7 +44,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
api.Logger.Error(err)
Logger.Error(err)
return
}
for _, file := range files {
@@ -48,11 +53,11 @@ func loadContent() {
}
content, err := os.ReadFile(file)
if err != nil {
api.Logger.Errorf("failed to read error page resource %s: %s", file, err)
Logger.Errorf("failed to read error page resource %s: %s", file, err)
continue
}
file = path.Base(file)
api.Logger.Infof("error page resource %s loaded", file)
Logger.Infof("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
@@ -72,17 +77,14 @@ func watchDir() {
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
api.Logger.Infof("error page resource %s deleted", filename)
Logger.Infof("error page resource %s deleted", filename)
case events.ActionFileRenamed:
api.Logger.Infof("error page resource %s deleted", filename)
Logger.Infof("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
api.Logger.Errorf("error watching error page directory: %s", err)
Logger.Errorf("error watching error page directory: %s", err)
}
}
}
var dirWatcher W.Watcher
var fileContentMap = F.NewMapOf[string, []byte]()

View File

@@ -1,6 +1,10 @@
package error_page
package errorpage
import "net/http"
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func GetHandleFunc() http.HandlerFunc {
setup()
@@ -21,5 +25,7 @@ func serveHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
w.Write(content)
if _, err := w.Write(content); err != nil {
HandleErr(w, r, err)
}
}

View File

@@ -24,7 +24,7 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, err)
return
}
w.Write(content)
U.WriteBody(w, content)
}
func SetFileContent(w http.ResponseWriter, r *http.Request) {
@@ -47,11 +47,11 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
}
if validateErr != nil {
U.RespondJson(w, validateErr.JSONObject(), http.StatusBadRequest)
U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest)
return
}
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0o644)
if err != nil {
U.HandleErr(w, r, err)
return

View File

@@ -1,7 +1,11 @@
package v1
import "net/http"
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func Index(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("API ready"))
WriteBody(w, []byte("API ready"))
}

View File

@@ -55,7 +55,7 @@ func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
}
}
U.HandleErr(w, r, U.RespondJson(w, routes))
U.RespondJSON(w, r, routes)
}
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
@@ -67,21 +67,21 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
for i := range files {
files[i] = strings.TrimPrefix(files[i], common.ConfigBasePath+"/")
}
U.HandleErr(w, r, U.RespondJson(w, files))
U.RespondJSON(w, r, files)
}
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, middleware.GetAllTrace()))
U.RespondJSON(w, r, middleware.GetAllTrace())
}
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, middleware.All()))
U.RespondJSON(w, r, middleware.All())
}
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, cfg.Value().MatchDomains))
U.RespondJSON(w, r, cfg.Value().MatchDomains)
}
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, cfg.HomepageConfig()))
U.RespondJSON(w, r, cfg.HomepageConfig())
}

View File

@@ -9,7 +9,7 @@ import (
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err != nil {
U.RespondJson(w, err.JSONObject(), http.StatusInternalServerError)
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
}

View File

@@ -5,18 +5,17 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, getStats(cfg)))
U.RespondJSON(w, r, getStats(cfg))
}
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
@@ -42,6 +41,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
return
}
/* trunk-ignore(golangci-lint/errcheck) */
defer conn.CloseNow()
ctx, cancel := context.WithCancel(context.Background())

View File

@@ -8,20 +8,22 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var HTTPClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
var (
HTTPClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
var Get = HTTPClient.Get
var Post = HTTPClient.Post
var Head = HTTPClient.Head
Get = HTTPClient.Get
Post = HTTPClient.Post
Head = HTTPClient.Head
)

View File

@@ -5,16 +5,26 @@ import (
"net/http"
)
func RespondJson(w http.ResponseWriter, data any, code ...int) error {
func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil {
HandleErr(w, nil, err)
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
} else {
w.Write(j)
HandleErr(w, r, err)
return false
}
return nil
_, err = w.Write(j)
if err != nil {
HandleErr(w, r, err)
return false
}
return true
}

View File

@@ -3,9 +3,10 @@ package v1
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/pkg"
)
func GetVersion(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(pkg.GetVersion()))
WriteBody(w, []byte(pkg.GetVersion()))
}

View File

@@ -16,22 +16,23 @@ import (
"github.com/go-acme/lego/v4/registration"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
U "github.com/yusing/go-proxy/internal/utils"
)
type Provider struct {
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
type (
Provider struct {
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
tlsCert *tls.Certificate
certExpiries CertExpiries
}
tlsCert *tls.Certificate
certExpiries CertExpiries
}
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type CertExpiries map[string]time.Time
CertExpiries map[string]time.Time
)
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
@@ -192,8 +193,8 @@ func (p *Provider) registerACME() E.NestedError {
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
//* This should have been done in setup
//* but double check is always a good choice
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {

View File

@@ -1,8 +1,9 @@
package autocert
import (
"github.com/go-acme/lego/v4/registration"
"crypto"
"github.com/go-acme/lego/v4/registration"
)
type User struct {
@@ -19,4 +20,4 @@ func (u *User) GetRegistration() *registration.Resource {
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}
}

View File

@@ -36,14 +36,12 @@ const (
ErrorPagesBasePath = "error_pages"
)
var (
RequiredDirectories = []string{
ConfigBasePath,
SchemaBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
)
var RequiredDirectories = []string{
ConfigBasePath,
SchemaBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
const DockerHostFromEnv = "$DOCKER_HOST"

View File

@@ -8,7 +8,6 @@ import (
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
@@ -212,7 +211,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(dockerHost))
b.Add(p.LoadRoutes().Subject(p.GetName()))
}
return
}
@@ -220,7 +219,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
errors.Add(err.Subject(p))
}

View File

@@ -133,7 +133,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
}
func CloseAllClients() {
clientMap.RangeAll(func(_ string, c Client) {
clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()

View File

@@ -7,7 +7,6 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
E "github.com/yusing/go-proxy/internal/error"
)

View File

@@ -1,7 +1,6 @@
package docker
import (
"fmt"
"strconv"
"strings"
@@ -39,7 +38,7 @@ func FromDocker(c *types.Container, dockerHost string) (res Container) {
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
func FromJSON(json types.ContainerJSON, dockerHost string) Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
for _, v := range bindings {
@@ -76,9 +75,8 @@ func (c Container) getDeleteLabel(label string) string {
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
}
return []string{c.getName()}
}
func (c Container) getName() string {
@@ -97,7 +95,7 @@ func (c Container) getPublicPortMapping() PortMapping {
if v.PublicPort == 0 {
continue
}
res[fmt.Sprint(v.PublicPort)] = v
res[U.PortString(v.PublicPort)] = v
}
return res
}
@@ -105,7 +103,7 @@ func (c Container) getPublicPortMapping() PortMapping {
func (c Container) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
res[fmt.Sprint(v.PrivatePort)] = v
res[U.PortString(v.PrivatePort)] = v
}
return res
}

View File

@@ -18,9 +18,9 @@ type templateData struct {
var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
const headerCheckRedirect = "X-GoProxy-Check-Redirect"
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
func (w *watcher) makeRespBody(format string, args ...any) []byte {
func (w *Watcher) makeRespBody(format string, args ...any) []byte {
msg := fmt.Sprintf(format, args...)
data := new(templateData)

View File

@@ -2,30 +2,41 @@ package idlewatcher
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strconv"
"time"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
type Waker struct {
*watcher
*Watcher
client *http.Client
rp *gphttp.ReverseProxy
}
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
tr := &http.Transport{}
if w.NoTLSVerify {
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker {
orig := rp.ServeHTTP
// workaround for stopped containers port become zero
rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) {
if rp.TargetURL.Port() == "0" {
port, ok := portHistoryMap.Load(w.Alias)
if !ok {
w.l.Errorf("port history not found for %s", w.Alias)
http.Error(rw, "internal server error", http.StatusInternalServerError)
return
}
rp.TargetURL.Host = fmt.Sprintf("%s:%v", rp.TargetURL.Hostname(), port)
}
orig(rw, r)
}
return &Waker{
watcher: w,
Watcher: w,
client: &http.Client{
Timeout: 1 * time.Second,
Transport: tr,
Transport: rp.Transport,
},
rp: rp,
}
@@ -36,6 +47,8 @@ func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Request) {
w.resetIdleTimer()
// pass through if container is ready
if w.ready.Load() {
next(rw, r)
@@ -45,10 +58,21 @@ func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Requ
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
defer cancel()
if r.Header.Get(headerCheckRedirect) == "" {
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeRespBody("%s waking up...", w.ContainerName)
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Write(w.makeRespBody("%s waking up...", w.ContainerName))
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
}
return
}
@@ -63,7 +87,11 @@ func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Requ
// maybe another request came in while we were waiting for the wake
if w.ready.Load() {
next(rw, r)
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
} else {
next(rw, r)
}
return
}
@@ -87,11 +115,15 @@ func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Requ
return
}
// we don't care about the response
_, err = w.client.Do(wakeReq)
if err == nil {
wakeResp, err := w.client.Do(wakeReq)
if err == nil && wakeResp.StatusCode != http.StatusServiceUnavailable {
w.ready.Store(true)
rw.WriteHeader(http.StatusOK)
w.l.Debug("awaken")
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
} else {
next(rw, r)
}
return
}

View File

@@ -12,11 +12,12 @@ import (
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
)
type (
watcher struct {
Watcher struct {
*P.ReverseProxyEntry
client D.Client
@@ -26,6 +27,7 @@ type (
wakeCh chan struct{}
wakeDone chan E.NestedError
ticker *time.Ticker
ctx context.Context
cancel context.CancelFunc
@@ -44,15 +46,17 @@ var (
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
portHistoryMap = F.NewMapOf[PT.Alias, string]()
newWatcherCh = make(chan *Watcher)
logger = logrus.WithField("module", "idle_watcher")
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 {
@@ -64,7 +68,11 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
key := entry.ContainerID
if w, ok := watcherMap[key]; ok {
if entry.URL.Port() != "0" {
portHistoryMap.Store(entry.Alias, entry.URL.Port())
}
if w, ok := watcherMap.Load(key); ok {
w.refCount.Add(1)
w.ReverseProxyEntry = entry
return w, nil
@@ -75,18 +83,19 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
return nil, failure.With(err)
}
w := &watcher{
w := &Watcher{
ReverseProxyEntry: entry,
client: client,
refCount: &sync.WaitGroup{},
wakeCh: make(chan struct{}),
wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError),
ticker: time.NewTicker(entry.IdleTimeout),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.stopByMethod = w.getStopCallback()
watcherMap[key] = w
watcherMap.Store(key, w)
go func() {
newWatcherCh <- w
@@ -95,7 +104,7 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
return w, nil
}
func (w *watcher) Unregister() {
func (w *Watcher) Unregister() {
w.refCount.Add(-1)
}
@@ -116,8 +125,7 @@ func Start() {
w.watchUntilCancel()
w.refCount.Wait() // wait for 0 ref count
w.client.Close()
delete(watcherMap, w.ContainerID)
watcherMap.Delete(w.ContainerID)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
@@ -130,29 +138,30 @@ func Stop() {
mainLoopWg.Wait()
}
func (w *watcher) containerStop() error {
func (w *Watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
Timeout: &w.StopTimeout,
})
}
func (w *watcher) containerPause() error {
func (w *Watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerID)
}
func (w *watcher) containerKill() error {
func (w *Watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
func (w *Watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerID)
}
func (w *watcher) containerStart() error {
func (w *Watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
func (w *Watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerID)
if err != nil {
return "", E.FailWith("inspect container", err)
@@ -160,7 +169,7 @@ func (w *watcher) containerStatus() (string, E.NestedError) {
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() E.NestedError {
func (w *Watcher) wakeIfStopped() E.NestedError {
if w.ready.Load() || w.ContainerRunning {
return nil
}
@@ -183,7 +192,7 @@ func (w *watcher) wakeIfStopped() E.NestedError {
}
}
func (w *watcher) getStopCallback() StopCallback {
func (w *Watcher) getStopCallback() StopCallback {
var cb func() error
switch w.StopMethod {
case PT.StopMethodPause:
@@ -207,10 +216,14 @@ func (w *watcher) getStopCallback() StopCallback {
}
}
func (w *watcher) watchUntilCancel() {
func (w *Watcher) resetIdleTimer() {
w.ticker.Reset(w.IdleTimeout)
}
func (w *Watcher) watchUntilCancel() {
defer close(w.wakeCh)
w.ctx, w.cancel = context.WithCancel(context.Background())
w.ctx, w.cancel = context.WithCancel(mainLoopCtx)
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
@@ -225,14 +238,11 @@ func (w *watcher) watchUntilCancel() {
W.DockerFilterUnpause,
),
})
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
defer w.ticker.Stop()
defer w.client.Close()
for {
select {
case <-mainLoopCtx.Done():
w.cancel()
case <-w.ctx.Done():
w.l.Debug("stopped")
return
@@ -244,22 +254,24 @@ func (w *watcher) watchUntilCancel() {
switch {
// create / start / unpause
case e.Action.IsContainerWake():
ticker.Reset(w.IdleTimeout)
w.ContainerRunning = true
w.resetIdleTimer()
w.l.Info(e)
default: // stop / pause / kill
ticker.Stop()
default: // stop / pause / kil
w.ContainerRunning = false
w.ticker.Stop()
w.ready.Store(false)
w.l.Info(e)
}
case <-ticker.C:
case <-w.ticker.C:
w.l.Debug("idle timeout")
ticker.Stop()
w.ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
}
case <-w.wakeCh:
w.l.Debug("wake signal received")
ticker.Reset(w.IdleTimeout)
w.resetIdleTimer()
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))

View File

@@ -15,5 +15,5 @@ func (c Client) Inspect(containerID string) (Container, E.NestedError) {
if err != nil {
return Container{}, E.From(err)
}
return FromJson(json, c.key), nil
return FromJSON(json, c.key), nil
}

View File

@@ -47,7 +47,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
case *Label:
var field reflect.Value
objType := reflect.TypeFor[T]()
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
for i := range reflect.TypeFor[T]().NumField() {
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
field = reflect.ValueOf(obj).Elem().Field(i)
break

View File

@@ -8,14 +8,18 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
mName = "middleware1"
mAttr = "prop1"
v = "value1"
)
func makeLabel(ns, name, attr string) string {
return fmt.Sprintf("%s.%s.%s", ns, name, attr)
}
func TestNestedLabel(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[*Label](t, pl.Value)
@@ -28,9 +32,6 @@ func TestApplyNestedLabel(t *testing.T) {
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
@@ -42,10 +43,6 @@ func TestApplyNestedLabel(t *testing.T) {
}
func TestApplyNestedLabelExisting(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
checkAttr := "prop2"
checkV := "value2"
entry := new(struct {
@@ -71,9 +68,6 @@ func TestApplyNestedLabelExisting(t *testing.T) {
}
func TestApplyNestedLabelNoAttr(t *testing.T) {
mName := "middleware1"
v := "value1"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})

View File

@@ -2,24 +2,26 @@ package docker
import "github.com/docker/docker/api/types"
type PortMapping = map[string]types.Port
type ProxyProperties struct {
DockerHost string `yaml:"-" json:"docker_host"`
ContainerName string `yaml:"-" json:"container_name"`
ContainerID string `yaml:"-" json:"container_id"`
ImageName string `yaml:"-" json:"image_name"`
PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port
NetworkMode string `yaml:"-" json:"network_mode"`
type (
PortMapping = map[string]types.Port
ProxyProperties struct {
DockerHost string `json:"docker_host" yaml:"-"`
ContainerName string `json:"container_name" yaml:"-"`
ContainerID string `json:"container_id" yaml:"-"`
ImageName string `json:"image_name" yaml:"-"`
PublicPortMapping PortMapping `json:"public_ports" yaml:"-"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `json:"private_ports" yaml:"-"` // privatePort:types.Port
NetworkMode string `json:"network_mode" yaml:"-"`
Aliases []string `yaml:"-" json:"aliases"`
IsExcluded bool `yaml:"-" json:"is_excluded"`
IsExplicit bool `yaml:"-" json:"is_explicit"`
IsDatabase bool `yaml:"-" json:"is_database"`
IdleTimeout string `yaml:"-" json:"idle_timeout"`
WakeTimeout string `yaml:"-" json:"wake_timeout"`
StopMethod string `yaml:"-" json:"stop_method"`
StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only
Running bool `yaml:"-" json:"running"`
}
Aliases []string `json:"aliases" yaml:"-"`
IsExcluded bool `json:"is_excluded" yaml:"-"`
IsExplicit bool `json:"is_explicit" yaml:"-"`
IsDatabase bool `json:"is_database" yaml:"-"`
IdleTimeout string `json:"idle_timeout" yaml:"-"`
WakeTimeout string `json:"wake_timeout" yaml:"-"`
StopMethod string `json:"stop_method" yaml:"-"`
StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only
StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only
Running bool `json:"running" yaml:"-"`
}
)

View File

@@ -21,7 +21,7 @@ func NewBuilder(format string, args ...any) Builder {
}
// adding nil / nil is no-op,
// you may safely pass expressions returning error to it
// you may safely pass expressions returning error to it.
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
@@ -39,6 +39,13 @@ func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
func (b Builder) AddRangeE(errs ...error) Builder {
for _, err := range errs {
b.AddE(err)
}
return b
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
@@ -56,15 +63,20 @@ func (b Builder) Build() NestedError {
}
func (b Builder) To(ptr *NestedError) {
if ptr == nil {
switch {
case ptr == nil:
return
} else if *ptr == nil {
case *ptr == nil:
*ptr = b.Build()
} else {
(*ptr).With(b.Build())
default:
(*ptr).extras = append((*ptr).extras, *b.Build())
}
}
func (b Builder) String() string {
return b.Build().String()
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
}

View File

@@ -33,15 +33,13 @@ func TestBuilderNested(t *testing.T) {
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().String()
expected1 :=
(`error occurred:
expected1 := (`error occurred:
- Action 1 failed:
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner: 3`)
expected2 :=
(`error occurred:
expected2 := (`error occurred:
- Action 1 failed:
- invalid Inner: "1"
- invalid Inner: "2"

View File

@@ -8,16 +8,16 @@ import (
)
type (
NestedError = *nestedError
nestedError struct {
NestedError = *NestedErrorImpl
NestedErrorImpl struct {
subject string
err error
extras []nestedError
extras []NestedErrorImpl
}
jsonNestedError struct {
Subject string
Err string
Extras []jsonNestedError
JSONNestedError struct {
Subject string `json:"subject"`
Err string `json:"error"`
Extras []JSONNestedError `json:"extras,omitempty"`
}
)
@@ -25,18 +25,18 @@ func From(err error) NestedError {
if IsNil(err) {
return nil
}
return &nestedError{err: err}
return &NestedErrorImpl{err: err}
}
func FromJSON(data []byte) (NestedError, bool) {
var j jsonNestedError
var j JSONNestedError
if err := json.Unmarshal(data, &j); err != nil {
return nil, false
}
if j.Err == "" {
return nil, false
}
extras := make([]nestedError, len(j.Extras))
extras := make([]NestedErrorImpl, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
@@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) {
}
extras[i] = *extra
}
return &nestedError{
return &NestedErrorImpl{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
@@ -58,26 +58,26 @@ func Check[T any](obj T, err error) (T, NestedError) {
}
func Join(message string, err ...NestedError) NestedError {
extras := make([]nestedError, len(err))
extras := make([]NestedErrorImpl, len(err))
nErr := 0
for i, e := range err {
if e == nil {
continue
}
extras[i] = *e
nErr += 1
nErr++
}
if nErr == 0 {
return nil
}
return &nestedError{
return &NestedErrorImpl{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
b := NewBuilder("%s", message)
for _, e := range err {
b.AddE(e)
}
@@ -151,7 +151,7 @@ func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any) NestedError {
func (ne NestedError) Subject(s any, sep ...string) NestedError {
if ne == nil {
return ne
}
@@ -164,11 +164,12 @@ func (ne NestedError) Subject(s any) NestedError {
default:
subject = fmt.Sprint(s)
}
if ne.subject == "" {
switch {
case ne.subject == "":
ne.subject = subject
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
} else {
case len(sep) > 0:
ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject)
default:
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
return ne
@@ -178,21 +179,15 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q")
}
if strings.Contains(format, "%w") {
panic("Subjectf format should not contain %w")
}
return ne.Subject(fmt.Sprintf(format, args...))
}
func (ne NestedError) JSONObject() jsonNestedError {
extras := make([]jsonNestedError, len(ne.extras))
func (ne NestedError) JSONObject() JSONNestedError {
extras := make([]JSONNestedError, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return jsonNestedError{
return JSONNestedError{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
@@ -200,7 +195,10 @@ func (ne NestedError) JSONObject() jsonNestedError {
}
func (ne NestedError) JSON() []byte {
b, _ := json.MarshalIndent(ne.JSONObject(), "", " ")
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
if err != nil {
panic(err)
}
return b
}
@@ -216,7 +214,7 @@ func errorf(format string, args ...any) NestedError {
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj jsonNestedError) (NestedError, bool) {
func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
@@ -240,7 +238,7 @@ func (ne NestedError) appendMsg(msg string) NestedError {
}
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
for i := 0; i < level; i++ {
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)
@@ -267,7 +265,7 @@ func (ne NestedError) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for i := 0; i < level; i++ {
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)

View File

@@ -1,10 +1,9 @@
package error_test
package error
import (
"errors"
"testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -88,8 +87,7 @@ func TestErrorNested(t *testing.T) {
With("baz").
With(inner).
With(inner.With(inner2.With(inner3)))
want :=
`foo failed:
want := `foo failed:
- bar
- baz
- inner failed:

View File

@@ -4,12 +4,11 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"time"
"log"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -21,8 +20,10 @@ type GitHubContents struct { //! keep this, may reuse in future
Size int `json:"size"`
}
const iconsCachePath = "/tmp/icons_cache.json"
const updateInterval = 1 * time.Hour
const (
iconsCachePath = "/tmp/icons_cache.json"
updateInterval = 1 * time.Hour
)
func ListAvailableIcons() ([]string, error) {
owner := "walkxcode"
@@ -30,13 +31,14 @@ func ListAvailableIcons() ([]string, error) {
ref := "main"
var lastUpdate time.Time
var icons = make([]string, 0)
icons := make([]string, 0)
info, err := os.Stat(iconsCachePath)
if err == nil {
lastUpdate = info.ModTime().Local()
}
if time.Since(lastUpdate) < updateInterval {
err := utils.LoadJson(iconsCachePath, &icons)
err := utils.LoadJSON(iconsCachePath, &icons)
if err == nil {
return icons, nil
}
@@ -51,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
icons = append(icons, content.Path)
}
}
err = utils.SaveJson(iconsCachePath, &icons, 0o644).Error()
err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error()
if err != nil {
log.Print("error saving cache", err)
}
@@ -59,7 +61,7 @@ func ListAvailableIcons() ([]string, error) {
}
func getRepoContents(client *http.Client, owner string, repo string, ref string, path string) ([]GitHubContents, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ import (
)
type ContentType string
type AcceptContentType []ContentType
func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type")
@@ -19,6 +20,18 @@ func GetContentType(h http.Header) ContentType {
return ContentType(ct)
}
func GetAccept(h http.Header) AcceptContentType {
var accepts []ContentType
for _, v := range h["Accept"] {
ct, _, err := mime.ParseMediaType(v)
if err != nil {
continue
}
accepts = append(accepts, ContentType(ct))
}
return accepts
}
func (ct ContentType) IsHTML() bool {
return ct == "text/html" || ct == "application/xhtml+xml"
}
@@ -30,3 +43,34 @@ func (ct ContentType) IsJSON() bool {
func (ct ContentType) IsPlainText() bool {
return ct == "text/plain"
}
func (act AcceptContentType) IsEmpty() bool {
return len(act) == 0
}
func (act AcceptContentType) AcceptHTML() bool {
for _, v := range act {
if v.IsHTML() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptJSON() bool {
for _, v := range act {
if v.IsJSON() || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptPlainText() bool {
for _, v := range act {
if v.IsPlainText() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}

View File

@@ -0,0 +1,41 @@
package http
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestContentTypes(t *testing.T) {
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
}
func TestAcceptContentTypes(t *testing.T) {
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
}

View File

@@ -0,0 +1,33 @@
package loadbalancer
import (
"hash/fnv"
"net"
"net/http"
)
type ipHash struct{ *LoadBalancer }
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
if !impl.pool[idx].available.Load() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
impl.pool[idx].handler.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {
h := fnv.New32a()
h.Write([]byte(ip))
return h.Sum32()
}

View File

@@ -0,0 +1,53 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type leastConn struct {
*LoadBalancer
nConn F.Map[*Server, *atomic.Int64]
}
func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{
LoadBalancer: lb,
nConn: F.NewMapOf[*Server, *atomic.Int64](),
}
}
func (impl *leastConn) OnAddServer(srv *Server) {
impl.nConn.Store(srv, new(atomic.Int64))
}
func (impl *leastConn) OnRemoveServer(srv *Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {
minConn = nConn
srv = srvs[i]
}
}
minConn.Add(1)
srv.handler.ServeHTTP(rw, r)
minConn.Add(-1)
}

View File

@@ -0,0 +1,238 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
)
// TODO: stats of each server.
// TODO: support weighted mode.
type (
impl interface {
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv *Server)
OnRemoveServer(srv *Server)
}
Config struct {
Link string
Mode Mode
Weight weightType
}
LoadBalancer struct {
impl
Config
pool servers
poolMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
done chan struct{}
sumWeight weightType
}
weightType uint16
)
const maxWeight weightType = 100
func New(cfg Config) *LoadBalancer {
lb := &LoadBalancer{Config: cfg, pool: servers{}}
mode := cfg.Mode
if !cfg.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode)
}
switch mode {
case RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
return lb
}
func (lb *LoadBalancer) AddServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = append(lb.pool, srv)
lb.sumWeight += srv.Weight
lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.RLock()
defer lb.poolMu.RUnlock()
lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool {
if s == srv {
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
break
}
}
if lb.IsEmpty() {
lb.Stop()
return
}
lb.Rebalance()
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
}
func (lb *LoadBalancer) IsEmpty() bool {
return len(lb.pool) == 0
}
func (lb *LoadBalancer) Rebalance() {
if lb.sumWeight == maxWeight {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(len(lb.pool))
remainder := maxWeight % weightType(len(lb.pool))
for _, s := range lb.pool {
s.Weight = weightEach
lb.sumWeight += weightEach
if remainder > 0 {
s.Weight++
remainder--
}
}
return
}
// scale evenly
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
for _, s := range lb.pool {
s.Weight = weightType(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight
}
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
for _, s := range lb.pool {
if delta == 0 {
break
}
if delta > 0 {
s.Weight++
lb.sumWeight++
delta--
} else {
s.Weight--
lb.sumWeight--
delta++
}
}
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srvs := lb.availServers()
if len(srvs) == 0 {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
lb.impl.ServeHTTP(srvs, rw, r)
}
func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
for _, s := range lb.pool {
msg.Addf("%s: %d", s.Name, s.Weight)
}
lb.Rebalance()
inner := E.NewBuilder("after rebalancing")
for _, s := range lb.pool {
inner.Addf("%s: %d", s.Name, s.Weight)
}
msg.Addf("%s", inner)
logger.Warn(msg)
}
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
lb.done = make(chan struct{}, 1)
lb.ctx, lb.cancel = context.WithCancel(context.Background())
updateAll := func() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
var wg sync.WaitGroup
wg.Add(len(lb.pool))
for _, s := range lb.pool {
go func(s *Server) {
defer wg.Done()
s.checkUpdateAvail(lb.ctx)
}(s)
}
wg.Wait()
}
logger.Debugf("loadbalancer %s started", lb.Link)
go func() {
defer lb.cancel()
defer close(lb.done)
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
updateAll()
for {
select {
case <-lb.ctx.Done():
return
case <-ticker.C:
updateAll()
}
}
}()
}
func (lb *LoadBalancer) Stop() {
if lb.cancel == nil {
return
}
lb.cancel()
<-lb.done
lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
}
func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
avail := servers{}
for _, s := range lb.pool {
if s.available.Load() {
avail = append(avail, s)
}
}
return avail
}

View File

@@ -0,0 +1,43 @@
package loadbalancer
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(Config{})
for range 10 {
lb.AddServer(&Server{})
}
lb.Rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(Config{})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(Config{})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
}

View File

@@ -0,0 +1,5 @@
package loadbalancer
import "github.com/sirupsen/logrus"
var logger = logrus.WithField("module", "load_balancer")

View File

@@ -0,0 +1,29 @@
package loadbalancer
import (
U "github.com/yusing/go-proxy/internal/utils"
)
type Mode string
const (
RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn"
IPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch U.ToLowerNoSnake(string(*mode)) {
case "", string(RoundRobin):
*mode = RoundRobin
return true
case string(LeastConn):
*mode = LeastConn
return true
case string(IPHash):
*mode = IPHash
return true
}
*mode = RoundRobin
return false
}

View File

@@ -0,0 +1,22 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
)
type roundRobin struct {
index atomic.Uint32
}
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1)
srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}
}

View File

@@ -0,0 +1,67 @@
package loadbalancer
import (
"context"
"net/http"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
Server struct {
Name string
URL types.URL
Weight weightType
handler http.Handler
pinger *http.Client
available atomic.Bool
}
servers []*Server
)
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
srv := &Server{
Name: name,
URL: url,
Weight: weight,
handler: handler,
pinger: &http.Client{Timeout: 3 * time.Second},
}
srv.available.Store(true)
return srv
}
func (srv *Server) checkUpdateAvail(ctx context.Context) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
srv.URL.String(),
nil,
)
if err != nil {
logger.Error("failed to create request: ", err)
srv.available.Store(false)
}
resp, err := srv.pinger.Do(req)
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
if !srv.available.Swap(true) {
logger.Infof("server %s is up", srv.Name)
}
} else if err != nil {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: %s", srv.Name, err)
}
} else {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
}
}
}
func (srv *Server) String() string {
return srv.Name
}

View File

@@ -5,7 +5,7 @@ import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)

View File

@@ -13,7 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
)
const (

View File

@@ -2,14 +2,14 @@ package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
@@ -23,14 +23,15 @@ var CustomErrorPage = &Middleware{
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
@@ -48,25 +49,27 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
}
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
filename := path[len(gphttp.StaticFilePathPrefix):]
file, ok := error_page.GetStaticFile(filename)
file, ok := errorpage.GetStaticFile(filename)
if !ok {
errPageLogger.Errorf("unable to load resource %s", filename)
return false
} else {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
w.Write(file)
return true
}
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}

View File

@@ -30,7 +30,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
}
middlewares = make(map[string]*Middleware)
for name, defs := range rawMap {
chainErr := E.NewBuilder(name)
chainErr := E.NewBuilder("%s", name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
@@ -64,7 +64,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
return
}
// TODO: check conflict or duplicates
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
m := &Middleware{name: name, children: chain}

View File

@@ -29,7 +29,6 @@ func init() {
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
@@ -37,6 +36,10 @@ func init() {
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
// !experimental
"forwardauth": ForwardAuth.m,
"oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {

View File

@@ -0,0 +1,129 @@
package middleware
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
E "github.com/yusing/go-proxy/internal/error"
)
type oAuth2 struct {
*oAuth2Opts
m *Middleware
}
type oAuth2Opts struct {
ClientID string
ClientSecret string
AuthURL string // Authorization Endpoint
TokenURL string // Token Endpoint
}
var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2},
}
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
oauth := new(oAuth2)
oauth.m = &Middleware{
impl: oauth,
before: oauth.handleOAuth2,
}
oauth.oAuth2Opts = &oAuth2Opts{}
err := Deserialize(opts, oauth.oAuth2Opts)
if err != nil {
return nil, err
}
b := E.NewBuilder("missing required fields")
optV := reflect.ValueOf(oauth.oAuth2Opts)
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
if optV.FieldByName(field.Name).Len() == 0 {
b.Add(E.Missing(field.Name))
}
}
if b.HasError() {
return nil, b.Build().Subject("oAuth2")
}
return oauth.m, nil
}
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// Check if the user is authenticated (you may use session, cookie, etc.)
if !userIsAuthenticated(r) {
// TODO: Redirect to OAuth2 auth URL
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
return
}
// If you have a token in the query string, process it
if code := r.URL.Query().Get("code"); code != "" {
// Exchange the authorization code for a token here
// Use the TokenURL and authenticate the user
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
if err != nil {
// handle error
http.Error(rw, "failed to get token", http.StatusUnauthorized)
return
}
// Save token and user info based on your requirements
saveToken(rw, token)
// Redirect to the originally requested URL
http.Redirect(rw, r, "/", http.StatusFound)
return
}
// If user is authenticated, go to the next handler
next(rw, r)
}
func userIsAuthenticated(r *http.Request) bool {
// Example: Check for a session or cookie
session, err := r.Cookie("session_token")
if err != nil || session.Value == "" {
return false
}
// Validate the session_token if necessary
return true
}
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// Prepare the request body
data := url.Values{
"client_id": {opts.ClientID},
"client_secret": {opts.ClientSecret},
"code": {code},
"grant_type": {"authorization_code"},
"redirect_uri": {requestURI},
}
resp, err := http.PostForm(opts.TokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to request token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
}
// Decode the response
var tokenResp struct {
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
return tokenResp.AccessToken, nil
}
func saveToken(rw ResponseWriter, token string) {
// Example: Save token in cookie
http.SetCookie(rw, &http.Cookie{
Name: "auth_token",
Value: token,
// set other properties as necessary, such as Secure and HttpOnly
})
}

View File

@@ -4,7 +4,7 @@ import (
"net"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html

View File

@@ -6,7 +6,7 @@ import (
"strings"
"testing"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View File

@@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
//go:embed test_data/sample_headers.json
@@ -73,7 +74,7 @@ type testArgs struct {
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rr = new(requestRecorder)
var rr requestRecorder
var proxyURL *url.URL
var requestTarget string
var err error
@@ -86,11 +87,14 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
body = bytes.NewReader(args.body)
}
if args.scheme == "" || args.scheme == "http" {
switch args.scheme {
case "":
fallthrough
case "http":
requestTarget = "http://" + testHost
} else if args.scheme == "https" {
case "https":
requestTarget = "https://" + testHost
} else {
default:
panic("typo?")
}
@@ -110,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gphttp.NewReverseProxy(proxyURL, rr)
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr

View File

@@ -24,9 +24,9 @@ import (
"sync"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
@@ -69,36 +69,6 @@ type ProxyRequest struct {
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Director is a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
//
// By default, the X-Forwarded-For header is set to the
// value of the client IP address. If an X-Forwarded-For
// header already exists, the client IP is appended to the
// existing values. As a special case, if the header
// exists in the Request.Header map but has a nil value
// (such as when set by the Director func), the X-Forwarded-For
// header is not modified.
//
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
//
// Hop-by-hop headers are removed from the request after
// Director returns, which can remove headers added by
// Director. Use a Rewrite function instead to ensure
// modifications to the request are preserved.
//
// Unparsable query parameters are removed from the outbound
// request if Request.Form is set after Director returns.
//
// At most one of Rewrite or Director may be set.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
@@ -115,6 +85,8 @@ type ReverseProxy struct {
ModifyResponse func(*http.Response) error
ServeHTTP http.HandlerFunc
TargetURL types.URL
}
func singleJoiningSlash(a, b string) string {
@@ -172,16 +144,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{
Director: func(req *http.Request) {
rewriteRequestURL(req, target)
},
Transport: transport,
}
rp := &ReverseProxy{Transport: transport, TargetURL: target}
rp.ServeHTTP = rp.serveHTTP
return rp
}
@@ -254,6 +221,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
@@ -296,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
p.Director(outreq)
rewriteRequestURL(outreq, p.TargetURL.URL)
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
@@ -381,18 +349,16 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
roundTripMutex.Unlock()
if err != nil {
p.errorHandler(rw, outreq, err, false)
errMsg := err.Error()
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq,
ContentLength: int64(len(errMsg)),
TLS: outreq.TLS,
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq,
TLS: outreq.TLS,
}
}
@@ -494,7 +460,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.errorHandler(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"), true)
p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true)
return
}
@@ -528,21 +494,24 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */
bdp.Start()
}
func IsPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
for _, r := range s {
if r < ' ' || r > '~' {
return false
}
}

24
internal/net/types/url.go Normal file
View File

@@ -0,0 +1,24 @@
package types
import "net/url"
type URL struct{ *url.URL }
func NewURL(url *url.URL) URL {
return URL{url}
}
func (u URL) String() string {
if u.URL == nil {
return "nil"
}
return u.URL.String()
}
func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u URL) Equals(other URL) bool {
return u.URL == other.URL || u.String() == other.String()
}

View File

@@ -7,6 +7,8 @@ import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
)
@@ -15,9 +17,10 @@ type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
URL *url.URL
URL net.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
LoadBalance loadbalancer.Config
Middlewares D.NestedLabelMap
/* Docker only */
@@ -47,6 +50,10 @@ func (rp *ReverseProxyEntry) IsDocker() bool {
return rp.DockerHost != ""
}
func (rp *ReverseProxyEntry) IsZeroPort() bool {
return rp.URL.Port() == "0"
}
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
m.FillMissingFields()
@@ -107,9 +114,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
LoadBalance: m.LoadBalance,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,

View File

@@ -4,8 +4,10 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
type Host string
type Subdomain = Alias
type (
Host string
Subdomain = Alias
)
func ValidateHost[String ~string](s String) (Host, E.NestedError) {
return Host(s), nil

View File

@@ -6,8 +6,12 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
type PathPattern string
type PathPatterns = []PathPattern
type (
PathPattern string
PathPatterns = []PathPattern
)
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func NewPathPattern(s string) (PathPattern, E.NestedError) {
if len(s) == 0 {
@@ -25,13 +29,11 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
}
pp := make(PathPatterns, len(s))
for i, v := range s {
if pattern, err := NewPathPattern(v); err.HasError() {
pattern, err := NewPathPattern(v)
if err != nil {
return nil, err
} else {
pp[i] = pattern
}
pp[i] = pattern
}
return pp, nil
}
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)

View File

@@ -1,7 +1,6 @@
package provider
import (
"fmt"
"regexp"
"strconv"
"strings"
@@ -9,7 +8,6 @@ import (
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
W "github.com/yusing/go-proxy/internal/watcher"
@@ -21,8 +19,10 @@ type DockerProvider struct {
ExplicitOnly bool
}
var AliasRefRegex = regexp.MustCompile(`#\d+`)
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
var (
AliasRefRegex = regexp.MustCompile(`#\d+`)
AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
)
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) {
hostname, err := D.ParseDockerHostname(dockerHost)
@@ -33,7 +33,7 @@ func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImp
}
func (p *DockerProvider) String() string {
return fmt.Sprintf("docker: %s", p.name)
return "docker: " + p.name
}
func (p *DockerProvider) NewWatcher() W.Watcher {
@@ -49,7 +49,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
return routes, E.FailWith("connect to docker", err)
}
errors := E.NewBuilder("errors when parse docker labels")
errors := E.NewBuilder("errors in docker labels")
for _, c := range info.Containers {
container := D.FromDocker(&c, p.dockerHost)
@@ -172,7 +172,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
}
// Returns a list of proxy entries for a container.
// Always non-nil
// Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (entries types.RawEntries, _ E.NestedError) {
entries = types.NewProxyEntries()
@@ -206,7 +206,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntr
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
refErr := E.NewBuilder("errors parsing alias references")
refErr := E.NewBuilder("errors in alias references")
replaceIndexRef := func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
@@ -231,7 +231,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntr
// apply label for all aliases
entries.RangeAll(func(a string, e *types.RawEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
b.Add(err)
}
})
} else {
@@ -250,7 +250,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntr
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
b.Add(err)
}
}
return

View File

@@ -52,12 +52,12 @@ func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult)
return
}
routes.RangeAll(func(_ string, v R.Route) {
routes.RangeAllParallel(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
newRoutes.RangeAllParallel(func(_ string, v R.Route) {
b.Add(v.Start())
})

View File

@@ -99,31 +99,21 @@ func (p *Provider) GetType() ProviderType {
return p.t
}
// to work with json marshaller
// to work with json marshaller.
func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) StartAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors in routes")
errors := E.NewBuilder("errors starting routes")
defer errors.To(&res)
// start watcher no matter load success or not
go p.watchEvents()
nStarted := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStarted++
}
p.routes.RangeAllParallel(func(alias string, r R.Route) {
errors.Add(r.Start().Subject(r))
})
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return
}
@@ -133,20 +123,12 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) {
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
errors := E.NewBuilder("errors stopping routes")
defer errors.To(&res)
nStopped := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStopped++
}
p.routes.RangeAllParallel(func(alias string, r R.Route) {
errors.Add(r.Stop().Subject(r))
})
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return
}
@@ -165,6 +147,9 @@ func (p *Provider) LoadRoutes() E.NestedError {
p.l.Infof("loaded %d routes", p.routes.Size())
return err
}
if err == nil {
return nil
}
return E.FailWith("loading routes", err)
}

View File

@@ -4,5 +4,7 @@ import (
"time"
)
const udpBufferSize = 8192
const streamStopListenTimeout = 1 * time.Second
const (
udpBufferSize = 8192
streamStopListenTimeout = 1 * time.Second
)

View File

@@ -1,18 +1,18 @@
package route
import (
"errors"
"fmt"
"net/http"
"strings"
"sync"
"net/http"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
@@ -21,20 +21,18 @@ import (
type (
HTTPRoute struct {
Alias PT.Alias `json:"alias"`
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
*P.ReverseProxyEntry
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer"`
entry *P.ReverseProxyEntry
server *loadbalancer.Server
handler http.Handler
rp *ReverseProxy
rp *gphttp.ReverseProxy
}
URL url.URL
SubdomainKey = PT.Alias
ReverseProxyHandler struct {
*ReverseProxy
*gphttp.ReverseProxy
}
)
@@ -43,7 +41,7 @@ var (
httpRoutes = F.NewMapOf[string, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
// globalMux = http.NewServeMux() // TODO: support regex subdomain matching.
)
func (rp ReverseProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -62,12 +60,12 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans *http.Transport
if entry.NoTLSVerify {
trans = DefaultTransportNoTLS.Clone()
trans = gphttp.DefaultTransportNoTLS.Clone()
} else {
trans = DefaultTransport.Clone()
trans = gphttp.DefaultTransport.Clone()
}
rp := NewReverseProxy(entry.URL, trans)
rp := gphttp.NewReverseProxy(entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares)
@@ -80,11 +78,8 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
defer httpRoutesMu.Unlock()
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
rp: rp,
ReverseProxyEntry: entry,
rp: rp,
}
return r, nil
}
@@ -101,18 +96,19 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.entry.UseIdleWatcher() {
watcher, err := idlewatcher.Register(r.entry)
switch {
case r.UseIdleWatcher():
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
if err != nil {
return err
}
r.handler = idlewatcher.NewWaker(watcher, r.rp)
} else if r.entry.URL.Port() == "0" ||
r.entry.IsDocker() && !r.entry.ContainerRunning {
case r.IsZeroPort() ||
r.IsDocker() && !r.ContainerRunning:
return nil
} else if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" {
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
r.handler = ReverseProxyHandler{r.rp}
} else {
default:
mux := http.NewServeMux()
for _, p := range r.PathPatterns {
mux.HandleFunc(string(p), r.rp.ServeHTTP)
@@ -120,7 +116,26 @@ func (r *HTTPRoute) Start() E.NestedError {
r.handler = mux
}
httpRoutes.Store(string(r.Alias), r)
if r.LoadBalance.Link == "" {
httpRoutes.Store(string(r.Alias), r)
return nil
}
var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok {
lb = linked.LoadBalancer
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
linked = &HTTPRoute{
LoadBalancer: lb,
handler: lb,
}
httpRoutes.Store(r.LoadBalance.Link, linked)
}
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)
lb.AddServer(r.server)
return nil
}
@@ -136,8 +151,21 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
waker.Unregister()
}
if r.server != nil {
linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok {
linked.LoadBalancer.RemoveServer(r.server)
}
if linked.LoadBalancer.IsEmpty() {
httpRoutes.Delete(r.LoadBalance.Link)
}
r.server = nil
} else {
httpRoutes.Delete(string(r.Alias))
}
r.handler = nil
httpRoutes.Delete(string(r.Alias))
return
}
@@ -145,14 +173,6 @@ func (r *HTTPRoute) Started() bool {
return r.handler != nil
}
func (u *URL) String() string {
return (*url.URL)(u).String()
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMuxFunc(r.Host)
if err != nil {
@@ -160,11 +180,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
logrus.Error(E.Failure("request").
Subjectf("%s %s", r.Method, r.URL.String()).
With(err))
errorPage, ok := error_page.GetErrorPageByStatus(http.StatusNotFound)
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(errorPage)
if _, err := w.Write(errorPage); err != nil {
logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err)
}
} else {
http.Error(w, err.Error(), http.StatusNotFound)
}
@@ -178,7 +200,7 @@ func findMuxAnyDomain(host string) (http.Handler, error) {
hostSplit := strings.Split(host, ".")
n := len(hostSplit)
if n <= 2 {
return nil, fmt.Errorf("missing subdomain in url")
return nil, errors.New("missing subdomain in url")
}
sd := strings.Join(hostSplit[:n-2], ".")
if r, ok := httpRoutes.Load(sd); ok {

View File

@@ -30,7 +30,7 @@ type StreamRoute struct {
type StreamImpl interface {
Setup() error
Accept() (any, error)
Handle(any) error
Handle(conn any) error
CloseListeners()
String() string
}

View File

@@ -36,7 +36,7 @@ func (route *TCPRoute) Setup() error {
if err != nil {
return err
}
//! this read the allocated port from orginal ':0'
//! this read the allocated port from original ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in
return nil

View File

@@ -51,7 +51,7 @@ func (route *UDPRoute) Setup() error {
return err
}
//! this read the allocated listeningPort from orginal ':0'
//! this read the allocated listeningPort from original ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
@@ -65,7 +65,6 @@ func (route *UDPRoute) Accept() (any, error) {
buffer := make([]byte, udpBufferSize)
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
@@ -108,7 +107,7 @@ func (route *UDPRoute) CloseListeners() {
route.listeningConn.Close()
route.listeningConn = nil
}
route.connMap.RangeAll(func(_ string, conn *UDPConn) {
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err)
}

View File

@@ -1,25 +1,25 @@
package server
var proxyServer, apiServer *server
var proxyServer, apiServer *Server
func InitProxyServer(opt Options) *server {
func InitProxyServer(opt Options) *Server {
if proxyServer == nil {
proxyServer = NewServer(opt)
}
return proxyServer
}
func InitAPIServer(opt Options) *server {
func InitAPIServer(opt Options) *Server {
if apiServer == nil {
apiServer = NewServer(opt)
}
return apiServer
}
func GetProxyServer() *server {
func GetProxyServer() *Server {
return proxyServer
}
func GetAPIServer() *server {
func GetAPIServer() *Server {
return apiServer
}

View File

@@ -2,6 +2,7 @@ package server
import (
"crypto/tls"
"errors"
"log"
"net/http"
"time"
@@ -11,7 +12,7 @@ import (
"golang.org/x/net/context"
)
type server struct {
type Server struct {
Name string
CertProvider *autocert.Provider
http *http.Server
@@ -38,7 +39,7 @@ func (l LogrusWrapper) Write(b []byte) (int, error) {
return l.Logger.WriterLevel(logrus.ErrorLevel).Write(b)
}
func NewServer(opt Options) (s *server) {
func NewServer(opt Options) (s *Server) {
var httpSer, httpsSer *http.Server
var httpHandler http.Handler
@@ -76,7 +77,7 @@ func NewServer(opt Options) (s *server) {
},
}
}
return &server{
return &Server{
Name: opt.Name,
CertProvider: opt.CertProvider,
http: httpSer,
@@ -88,8 +89,8 @@ func NewServer(opt Options) (s *server) {
//
// If both are not set, this does nothing.
//
// Start() is non-blocking
func (s *server) Start() {
// Start() is non-blocking.
func (s *Server) Start() {
if s.http == nil && s.https == nil {
return
}
@@ -112,7 +113,7 @@ func (s *server) Start() {
}
}
func (s *server) Stop() {
func (s *Server) Stop() {
if s.http == nil && s.https == nil {
return
}
@@ -133,13 +134,13 @@ func (s *server) Stop() {
}
}
func (s *server) Uptime() time.Duration {
func (s *Server) Uptime() time.Duration {
return time.Since(s.startTime)
}
func (s *server) handleErr(scheme string, err error) {
switch err {
case nil, http.ErrServerClosed:
func (s *Server) handleErr(scheme string, err error) {
switch {
case err == nil, errors.Is(err, http.ErrServerClosed):
return
default:
logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err)

View File

@@ -1,7 +1,6 @@
package internal
import (
"fmt"
"io"
"log"
"net/http"
@@ -9,16 +8,18 @@ import (
"os"
"path"
. "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/common"
)
var branch = GetEnv("GOPROXY_BRANCH", "v0.5")
var baseUrl = fmt.Sprintf("https://github.com/yusing/go-proxy/raw/%s", branch)
var requiredConfigs = []Config{
{ConfigBasePath, true, false, ""},
{ComposeFileName, false, true, ComposeExampleFileName},
{path.Join(ConfigBasePath, ConfigFileName), false, true, ConfigExampleFileName},
}
var (
branch = common.GetEnv("GOPROXY_BRANCH", "v0.6")
baseURL = "https://github.com/yusing/go-proxy/raw/" + branch
requiredConfigs = []Config{
{common.ConfigBasePath, true, false, ""},
{common.ComposeFileName, false, true, common.ComposeExampleFileName},
{path.Join(common.ConfigBasePath, common.ConfigFileName), false, true, common.ConfigExampleFileName},
}
)
type Config struct {
Pathname string
@@ -31,7 +32,9 @@ func Setup() {
log.Println("setting up go-proxy")
log.Println("branch:", branch)
os.Chdir("/setup")
if err := os.Chdir("/setup"); err != nil {
log.Fatalf("failed: %s\n", err)
}
for _, config := range requiredConfigs {
config.setup()
@@ -83,6 +86,7 @@ func touch(pathname string) {
log.Fatalf("failed: %s\n", err)
}
}
func fetch(remoteFilename string, outFileName string) {
if hasFileOrDir(outFileName) {
if remoteFilename == outFileName {
@@ -94,7 +98,7 @@ func fetch(remoteFilename string, outFileName string) {
}
log.Printf("downloading %q\n", remoteFilename)
url, err := url.JoinPath(baseUrl, remoteFilename)
url, err := url.JoinPath(baseURL, remoteFilename)
if err != nil {
log.Fatalf("unexpected error: %s\n", err)
}
@@ -104,17 +108,19 @@ func fetch(remoteFilename string, outFileName string) {
log.Fatalf("http request failed: %s\n", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
log.Fatalf("error reading response body: %s\n", err)
}
err = os.WriteFile(outFileName, body, 0o644)
if err != nil {
resp.Body.Close()
log.Fatalf("failed to write to file: %s\n", err)
}
log.Printf("downloaded to %q\n", outFileName)
resp.Body.Close()
}

View File

@@ -2,12 +2,12 @@ package types
type (
AutoCertConfig struct {
Email string `json:"email"`
Domains []string `yaml:",flow" json:"domains"`
CertPath string `yaml:"cert_path" json:"cert_path"`
KeyPath string `yaml:"key_path" json:"key_path"`
Provider string `json:"provider"`
Options AutocertProviderOpt `yaml:",flow" json:"options"`
Email string `json:"email,omitempty" yaml:"email"`
Domains []string `json:"domains,omitempty" yaml:",flow"`
CertPath string `json:"cert_path,omitempty" yaml:"cert_path"`
KeyPath string `json:"key_path,omitempty" yaml:"key_path"`
Provider string `json:"provider,omitempty" yaml:"provider"`
Options AutocertProviderOpt `json:"options,omitempty" yaml:",flow"`
}
AutocertProviderOpt map[string]any
)

View File

@@ -1,12 +1,12 @@
package types
type Config struct {
Providers ProxyProviders `yaml:",flow" json:"providers"`
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
ExplicitOnly bool `yaml:"explicit_only" json:"explicit_only"`
MatchDomains []string `yaml:"match_domains" json:"match_domains"`
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
Providers ProxyProviders `json:"providers" yaml:",flow"`
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
}
func DefaultConfig() *Config {

View File

@@ -1,6 +1,6 @@
package types
type ProxyProviders struct {
Files []string `yaml:"include" json:"include"` // docker, file
Docker map[string]string `yaml:"docker" json:"docker"`
Files []string `json:"include" yaml:"include"` // docker, file
Docker map[string]string `json:"docker" yaml:"docker"`
}

View File

@@ -1,13 +1,14 @@
package types
import (
"fmt"
"strconv"
"strings"
. "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
H "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -15,17 +16,18 @@ type (
RawEntry struct {
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
Port string `yaml:"port" json:"port"`
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only
PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"`
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage"`
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme" yaml:"scheme"`
Host string `json:"host" yaml:"host"`
Port string `json:"port" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
LoadBalance loadbalancer.Config `json:"load_balance" yaml:"load_balance"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *H.HomePageItem `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
*D.ProxyProperties `json:"proxy_properties" yaml:"-"`
}
RawEntries = F.Map[string, *RawEntry]
@@ -41,14 +43,14 @@ func (e *RawEntry) FillMissingFields() {
lp, pp, extra := e.splitPorts()
if port, ok := ServiceNamePortMapTCP[e.ImageName]; ok {
if port, ok := common.ServiceNamePortMapTCP[e.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
if e.Scheme == "" {
e.Scheme = "tcp"
}
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
} else if port, ok := common.ImageNamePortMap[e.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
@@ -59,7 +61,7 @@ func (e *RawEntry) FillMissingFields() {
pp = "443"
} else if pp == "" {
if p, ok := F.FirstValueOf(e.PrivatePortMapping); ok {
pp = fmt.Sprint(p.PrivatePort)
pp = U.PortString(p.PrivatePort)
} else if !isDocker {
pp = "80"
}
@@ -68,12 +70,12 @@ func (e *RawEntry) FillMissingFields() {
// replace private port with public port (if any)
if isDocker && e.NetworkMode != "host" {
if p, ok := e.PrivatePortMapping[pp]; ok {
pp = fmt.Sprint(p.PublicPort)
pp = U.PortString(p.PublicPort)
}
if _, ok := e.PublicPortMapping[pp]; !ok { // port is not exposed, but specified
// try to fallback to first public port
if p, ok := F.FirstValueOf(e.PublicPortMapping); ok {
pp = fmt.Sprint(p.PublicPort)
pp = U.PortString(p.PublicPort)
}
}
}
@@ -85,13 +87,12 @@ func (e *RawEntry) FillMissingFields() {
}
if e.Scheme == "" {
if lp != "" {
switch {
case lp != "":
e.Scheme = "tcp"
} else if strings.HasSuffix(pp, "443") {
case strings.HasSuffix(pp, "443"):
e.Scheme = "https"
} else if _, ok := WellKnownHTTPPorts[pp]; ok {
e.Scheme = "http"
} else {
default:
// assume its http
e.Scheme = "http"
}
@@ -101,16 +102,16 @@ func (e *RawEntry) FillMissingFields() {
e.Host = "localhost"
}
if e.IdleTimeout == "" {
e.IdleTimeout = IdleTimeoutDefault
e.IdleTimeout = common.IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = WakeTimeoutDefault
e.WakeTimeout = common.WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = StopTimeoutDefault
e.StopTimeout = common.StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = StopMethodDefault
e.StopMethod = common.StopMethodDefault
}
e.Port = joinPorts(lp, pp, extra)

View File

@@ -7,7 +7,7 @@ import (
)
// Recursively lists all files in a directory until `maxDepth` is reached
// Returns a slice of file paths relative to `dir`
// Returns a slice of file paths relative to `dir`.
func ListFiles(dir string, maxDepth int) ([]string, error) {
entries, err := os.ReadDir(dir)
if err != nil {

View File

@@ -1,10 +1,11 @@
package functional
import (
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/yaml.v3"
"sync"
"github.com/puzpuzpuz/xsync/v3"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
type Map[KT comparable, VT any] struct {
@@ -23,6 +24,17 @@ func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) {
return
}
// MapFind iterates over the map and returns the first value
// that satisfies the given criteria. The iteration is stopped
// once a value is found. If no value satisfies the criteria,
// the function returns the zero value of CT.
//
// The criteria function takes a value of type VT and returns a
// value of type CT and a boolean indicating whether the value
// satisfies the criteria. The boolean value is used to determine
// whether the iteration should be stopped.
//
// The function is safe for concurrent use.
func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) {
result := make(chan CT, 1)
@@ -47,13 +59,15 @@ func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bo
}
}
// MergeFrom add contents from another `Map`, ignore duplicated keys
// MergeFrom merges the contents of another Map into this one, ignoring duplicated keys.
//
// Parameters:
// - other: `Map` of values to add from
//
// Return:
// - Map: a `Map` of duplicated keys-value pairs
// other: Map of values to add from
//
// Returns:
//
// Map of duplicated keys-value pairs
func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] {
dups := NewMapOf[KT, VT]()
@@ -68,6 +82,15 @@ func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] {
return dups
}
// RangeAll calls the given function for each key-value pair in the map.
//
// Parameters:
//
// do: function to call for each key-value pair
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
m.Range(func(k KT, v VT) bool {
do(k, v)
@@ -75,6 +98,39 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
})
}
// RangeAllParallel calls the given function for each key-value pair in the map,
// in parallel. The map is not safe for modification from within the function.
//
// Parameters:
//
// do: function to call for each key-value pair
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) {
var wg sync.WaitGroup
wg.Add(m.Size())
m.Range(func(k KT, v VT) bool {
go func() {
do(k, v)
wg.Done()
}()
return true
})
wg.Wait()
}
// RemoveAll removes all key-value pairs from the map where the value matches the given criteria.
//
// Parameters:
//
// criteria: function to determine whether a value should be removed
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool {
if criteria(v) {
@@ -89,6 +145,17 @@ func (m Map[KT, VT]) Has(k KT) bool {
return ok
}
// UnmarshalFromYAML unmarshals a yaml byte slice into the map.
//
// It overwrites all existing key-value pairs in the map.
//
// Parameters:
//
// data: yaml byte slice to unmarshal
//
// Returns:
//
// error: if the unmarshaling fails
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
if m.Size() != 0 {
return E.FailedWhy("unmarshal from yaml", "map is not empty")

View File

@@ -38,6 +38,10 @@ func (s *Slice[T]) Iterator() []T {
return s.s
}
func (s *Slice[T]) Get(i int) T {
return s.s[i]
}
func (s *Slice[T]) Set(i int, v T) {
s.s[i] = v
}
@@ -76,6 +80,20 @@ func (s *Slice[T]) SafePop() T {
return s.Pop()
}
func (s *Slice[T]) Remove(criteria func(T) bool) {
for i, v2 := range s.s {
if criteria(v2) {
s.s = append(s.s[:i], s.s[i+1:]...)
}
}
}
func (s *Slice[T]) SafeRemove(criteria func(T) bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.Remove(criteria)
}
func (s *Slice[T]) ForEach(do func(T)) {
for _, v := range s.s {
do(v)

View File

@@ -6,12 +6,13 @@ import (
"errors"
"io"
"os"
"sync"
"syscall"
E "github.com/yusing/go-proxy/internal/error"
)
// TODO: move to "utils/io"
// TODO: move to "utils/io".
type (
FileReader struct {
Path string
@@ -28,10 +29,8 @@ type (
}
Pipe struct {
r ContextReader
w ContextWriter
ctx context.Context
cancel context.CancelFunc
r ContextReader
w ContextWriter
}
BidirectionalPipe struct {
@@ -59,12 +58,9 @@ func (w *ContextWriter) Write(p []byte) (int, error) {
}
func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
_, cancel := context.WithCancel(ctx)
return &Pipe{
r: ContextReader{ctx: ctx, Reader: r},
w: ContextWriter{ctx: ctx, Writer: w},
ctx: ctx,
cancel: cancel,
r: ContextReader{ctx: ctx, Reader: r},
w: ContextWriter{ctx: ctx, Writer: w},
}
}
@@ -87,22 +83,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
}
}
func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: NewPipe(ctx, listener, client),
pDstSrc: NewPipe(ctx, client, target),
}
}
func (p BidirectionalPipe) Start() error {
errCh := make(chan error, 2)
var wg sync.WaitGroup
wg.Add(2)
b := E.NewBuilder("bidirectional pipe error")
go func() {
errCh <- p.pSrcDst.Start()
b.AddE(p.pSrcDst.Start())
wg.Done()
}()
go func() {
errCh <- p.pDstSrc.Start()
b.AddE(p.pDstSrc.Start())
wg.Done()
}()
return E.JoinE("bidirectional pipe error", <-errCh, <-errCh).Error()
wg.Wait()
return b.Build().Error()
}
func Copy(dst *ContextWriter, src *ContextReader) error {
@@ -114,7 +108,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {
return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src})
}
func LoadJson[T any](path string, pointer *T) E.NestedError {
func LoadJSON[T any](path string, pointer *T) E.NestedError {
data, err := E.Check(os.ReadFile(path))
if err.HasError() {
return err
@@ -122,7 +116,7 @@ func LoadJson[T any](path string, pointer *T) E.NestedError {
return E.From(json.Unmarshal(data, pointer))
}
func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
data, err := E.Check(json.Marshal(pointer))
if err.HasError() {
return err

View File

@@ -3,6 +3,7 @@ package utils
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
@@ -14,10 +15,12 @@ import (
"gopkg.in/yaml.v3"
)
type SerializedObject = map[string]any
type Converter interface {
ConvertFrom(value any) (any, E.NestedError)
}
type (
SerializedObject = map[string]any
Converter interface {
ConvertFrom(value any) (any, E.NestedError)
}
)
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
var i any
@@ -37,11 +40,16 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
return nil
}
errors := E.NewBuilder("yaml validation error")
for _, e := range err.(*jsonschema.ValidationError).Causes {
errors.AddE(e)
var valErr *jsonschema.ValidationError
if !errors.As(err, &valErr) {
return E.UnexpectedError(err)
}
return errors.Build()
b := E.NewBuilder("yaml validation error")
for _, e := range valErr.Causes {
b.AddE(e)
}
return b.Build()
}
// Serialize converts the given data into a map[string]any representation.
@@ -80,7 +88,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
result[key.String()] = value.MapIndex(key).Interface()
}
case reflect.Struct:
for i := 0; i < value.NumField(); i++ {
for i := range value.NumField() {
field := value.Type().Field(i)
if !field.IsExported() {
continue
@@ -91,9 +99,10 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
}
// If the json tag is not empty, use it as the key
if jsonTag != "" {
switch {
case jsonTag != "":
result[jsonTag] = value.Field(i).Interface()
} else if field.Anonymous {
case field.Anonymous:
// If the field is an embedded struct, add its fields to the result
fieldMap, err := Serialize(value.Field(i).Interface())
if err != nil {
@@ -102,7 +111,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
for k, v := range fieldMap {
result[k] = v
}
} else {
default:
result[field.Name] = value.Field(i).Interface()
}
}
@@ -147,11 +156,11 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
// TODO: use E.Builder to collect errors from all fields
if dstV.Kind() == reflect.Struct {
switch dstV.Kind() {
case reflect.Struct:
mapping := make(map[string]reflect.Value)
for i := 0; i < dstV.NumField(); i++ {
field := dstT.Field(i)
mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i)
for _, field := range reflect.VisibleFields(dstT) {
mapping[ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name)
}
for k, v := range src {
if field, ok := mapping[ToLowerNoSnake(k)]; ok {
@@ -163,7 +172,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
return E.Unexpected("field", k).Subjectf("%T", dst)
}
}
} else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String {
case reflect.Map:
if dstV.IsNil() {
dstV.Set(reflect.MakeMap(dstT))
}
@@ -175,8 +184,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
}
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp)
}
return nil
} else {
default:
return E.Unsupported("target type", fmt.Sprintf("%T", dst))
}
@@ -322,7 +330,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
var tmp any
switch dst.Kind() {
case reflect.Slice:
// one liner is comma seperated list
// one liner is comma separated list
if len(lines) == 0 {
dst.Set(reflect.ValueOf(CommaSeperatedList(src)))
return
@@ -363,7 +371,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
return true, Convert(reflect.ValueOf(tmp), dst)
}
func DeserializeJson(j map[string]string, target any) E.NestedError {
func DeserializeJSON(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j))
if err != nil {
return err

View File

@@ -9,7 +9,7 @@ import (
"golang.org/x/text/language"
)
// TODO: support other languages
// TODO: support other languages.
var titleCaser = cases.Title(language.AmericanEnglish)
func CommaSeperatedList(s string) []string {
@@ -31,3 +31,7 @@ func ExtractPort(fullURL string) (int, error) {
}
return strconv.Atoi(url.Port())
}
func PortString(port uint16) string {
return strconv.FormatUint(uint64(port), 10)
}

View File

@@ -92,7 +92,6 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
_, ok := got.(T)
if !ok {
t.Fatalf("expected type %s, got %s", tExpect, reflect.TypeOf(got).Elem())
t.FailNow()
return
}
return got.(T)

View File

@@ -7,10 +7,12 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var configDirWatcher *dirWatcher
var configDirWatcherMu sync.Mutex
var (
configDirWatcher *DirWatcher
configDirWatcherMu sync.Mutex
)
// create a new file watcher for file under ConfigBasePath
// create a new file watcher for file under ConfigBasePath.
func NewConfigFileWatcher(filename string) Watcher {
configDirWatcherMu.Lock()
defer configDirWatcherMu.Unlock()

View File

@@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/watcher/events"
)
type dirWatcher struct {
type DirWatcher struct {
dir string
w *fsnotify.Watcher
@@ -26,7 +26,7 @@ type dirWatcher struct {
ctx context.Context
}
func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher {
func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher {
//! subdirectories are not watched
w, err := fsnotify.NewWatcher()
if err != nil {
@@ -35,7 +35,7 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher {
if err = w.Add(dirPath); err != nil {
logrus.Panicf("unable to create fs watcher: %s", err)
}
helper := &dirWatcher{
helper := &DirWatcher{
dir: dirPath,
w: w,
fwMap: F.NewMapOf[string, *fileWatcher](),
@@ -47,11 +47,11 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher {
return helper
}
func (h *dirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) {
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) {
return h.eventCh, h.errCh
}
func (h *dirWatcher) Add(relPath string) *fileWatcher {
func (h *DirWatcher) Add(relPath string) Watcher {
h.mu.Lock()
defer h.mu.Unlock()
@@ -85,7 +85,7 @@ func (h *dirWatcher) Add(relPath string) *fileWatcher {
return s
}
func (h *dirWatcher) start() {
func (h *DirWatcher) start() {
defer close(h.eventCh)
defer h.w.Close()