restructured the project to comply community guideline, for others check release note

This commit is contained in:
yusing
2024-09-28 09:51:34 +08:00
parent 4120fd8d1c
commit 90487bfde6
134 changed files with 1110 additions and 625 deletions

57
internal/api/handler.go Normal file
View File

@@ -0,0 +1,57 @@
package api
import (
"fmt"
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
)
type ServeMux struct{ *http.ServeMux }
func NewServeMux() ServeMux {
return ServeMux{http.NewServeMux()}
}
func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
}
func NewHandler(cfg *config.Config) http.Handler {
mux := NewServeMux()
mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
mux.HandleFunc("GET", "/v1/file/{filename}", v1.GetFileContent)
mux.HandleFunc("POST", "/v1/file/{filename}", v1.SetFileContent)
mux.HandleFunc("PUT", "/v1/file/{filename}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc())
return mux
}
// allow only requests to API server with host matching common.APIHTTPAddr
func checkHost(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Host != common.APIHTTPAddr {
Logger.Warnf("invalid request to API server with host: %s, expected: %s", r.Host, common.APIHTTPAddr)
w.WriteHeader(http.StatusNotFound)
return
}
f(w, r)
}
}
func wrap(cfg *config.Config, f func(cfg *config.Config, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
f(cfg, w, r)
}
}

View File

@@ -0,0 +1,42 @@
package v1
import (
"fmt"
"net/http"
"strings"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
R "github.com/yusing/go-proxy/internal/route"
)
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target")
if target == "" {
U.HandleErr(w, r, U.ErrMissingKey("target"), http.StatusBadRequest)
return
}
var ok bool
route := cfg.FindRoute(target)
switch {
case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return
case route.Type() == R.RouteTypeReverseProxy:
ok = U.IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = U.IsStreamHealthy(
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)
}
if ok {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusRequestTimeout)
}
}

View File

@@ -0,0 +1,88 @@
package error_page
import (
"context"
"fmt"
"os"
"path"
"sync"
api "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
loadContent()
go watchDir()
})
func GetStaticFile(filename string) ([]byte, bool) {
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
return
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
api.Logger.Error(err)
return
}
for _, file := range files {
if fileContentMap.Has(file) {
continue
}
content, err := os.ReadFile(file)
if err != nil {
api.Logger.Errorf("failed to read error page resource %s: %s", file, err)
continue
}
file = path.Base(file)
api.Logger.Infof("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(context.Background())
for {
select {
case event, ok := <-eventCh:
if !ok {
return
}
filename := event.ActorName
switch event.Action {
case events.ActionFileWritten:
fileContentMap.Delete(filename)
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
api.Logger.Infof("error page resource %s deleted", filename)
case events.ActionFileRenamed:
api.Logger.Infof("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
api.Logger.Errorf("error watching error page directory: %s", err)
}
}
}
var dirWatcher W.Watcher
var fileContentMap = F.NewMapOf[string, []byte]()

View File

@@ -0,0 +1,25 @@
package error_page
import "net/http"
func GetHandleFunc() http.HandlerFunc {
setup()
return serveHTTP
}
func serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
content, ok := fileContentMap.Load(r.URL.Path)
if !ok {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
w.Write(content)
}

59
internal/api/v1/file.go Normal file
View File

@@ -0,0 +1,59 @@
package v1
import (
"io"
"net/http"
"os"
"path"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/provider"
)
func GetFileContent(w http.ResponseWriter, r *http.Request) {
filename := r.PathValue("filename")
if filename == "" {
filename = common.ConfigFileName
}
content, err := os.ReadFile(path.Join(common.ConfigBasePath, filename))
if err != nil {
U.HandleErr(w, r, err)
return
}
w.Write(content)
}
func SetFileContent(w http.ResponseWriter, r *http.Request) {
filename := r.PathValue("filename")
if filename == "" {
U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest)
return
}
content, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
var validateErr E.NestedError
if filename == common.ConfigFileName {
validateErr = config.Validate(content)
} else {
validateErr = provider.Validate(content)
}
if validateErr != nil {
U.RespondJson(w, validateErr.JSONObject(), http.StatusBadRequest)
return
}
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
if err != nil {
U.HandleErr(w, r, err)
return
}
w.WriteHeader(http.StatusOK)
}

7
internal/api/v1/index.go Normal file
View File

@@ -0,0 +1,7 @@
package v1
import "net/http"
func Index(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("API ready"))
}

61
internal/api/v1/list.go Normal file
View File

@@ -0,0 +1,61 @@
package v1
import (
"encoding/json"
"net/http"
"os"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
)
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what")
if what == "" {
what = "routes"
}
switch what {
case "routes":
listRoutes(cfg, w, r)
case "config_files":
listConfigFiles(w, r)
default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
}
}
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias()
typeFilter := r.FormValue("type")
if typeFilter != "" {
for k, v := range routes {
if v["type"] != typeFilter {
delete(routes, k)
}
}
}
if err := U.RespondJson(w, routes); err != nil {
U.HandleErr(w, r, err)
}
}
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
files, err := os.ReadDir(common.ConfigBasePath)
if err != nil {
U.HandleErr(w, r, err)
return
}
filenames := make([]string, len(files))
for i, f := range files {
filenames[i] = f.Name()
}
resp, err := json.Marshal(filenames)
if err != nil {
U.HandleErr(w, r, err)
return
}
w.Write(resp)
}

16
internal/api/v1/reload.go Normal file
View File

@@ -0,0 +1,16 @@
package v1
import (
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
)
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err != nil {
U.RespondJson(w, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
}
}

20
internal/api/v1/stats.go Normal file
View File

@@ -0,0 +1,20 @@
package v1
import (
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils"
)
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
stats := map[string]any{
"proxies": cfg.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
}
if err := U.RespondJson(w, stats); err != nil {
U.HandleErr(w, r, err)
}
}

View File

@@ -0,0 +1,34 @@
package utils
import (
"errors"
"fmt"
"net/http"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
var Logger = logrus.WithField("module", "api")
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
Logger.Error(err)
if len(code) > 0 {
http.Error(w, err.String(), code[0])
return
}
http.Error(w, err.String(), http.StatusInternalServerError)
}
func ErrMissingKey(k string) error {
return errors.New("missing key '" + k + "' in query or request body")
}
func ErrInvalidKey(k string) error {
return errors.New("invalid key '" + k + "' in query or request body")
}
func ErrNotFound(k, v string) error {
return fmt.Errorf("key %q with value %q not found", k, v)
}

View File

@@ -0,0 +1,33 @@
package utils
import (
"net"
"net/http"
"github.com/yusing/go-proxy/internal/common"
)
func IsSiteHealthy(url string) bool {
// try HEAD first
// if HEAD is not allowed, try GET
resp, err := httpClient.Head(url)
if resp != nil {
resp.Body.Close()
}
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
_, err = httpClient.Get(url)
}
if resp != nil {
resp.Body.Close()
}
return err == nil
}
func IsStreamHealthy(scheme, address string) bool {
conn, err := net.DialTimeout(scheme, address, common.DialTimeout)
if err != nil {
return false
}
conn.Close()
return true
}

View File

@@ -0,0 +1,23 @@
package utils
import (
"crypto/tls"
"net"
"net/http"
"github.com/yusing/go-proxy/internal/common"
)
var httpClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}

View File

@@ -0,0 +1,31 @@
package utils
import (
"fmt"
"io"
"net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
func ReloadServer() E.NestedError {
resp, err := httpClient.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
if err != nil {
return E.From(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
b, err := io.ReadAll(resp.Body)
if err != nil {
return failure.Extraf("unable to read response body: %s", err)
}
reloadErr, ok := E.FromJSON(b)
if ok {
return E.Join("reload success, but server returned error", reloadErr)
}
return failure.Extraf("unable to read response body")
}
return nil
}

View File

@@ -0,0 +1,20 @@
package utils
import (
"encoding/json"
"net/http"
)
func RespondJson(w http.ResponseWriter, data any, code ...int) error {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
} else {
w.Write(j)
}
return nil
}

View File

@@ -0,0 +1,75 @@
package autocert
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
)
type Config M.AutoCertConfig
func NewConfig(cfg *M.AutoCertConfig) *Config {
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
}
if cfg.KeyPath == "" {
cfg.KeyPath = KeyFileDefault
}
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Addf("%s", "no domains specified")
}
if cfg.Provider == "" {
b.Addf("%s", "no provider specified")
}
if cfg.Email == "" {
b.Addf("%s", "no email specified")
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Addf("unknown provider: %q", cfg.Provider)
}
}
if b.HasError() {
return
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.HasError() {
b.Add(E.FailWith("generate private key", err))
return
}
user := &User{
Email: cfg.Email,
key: privKey,
}
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
provider = &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
}
return
}

View File

@@ -0,0 +1,40 @@
package autocert
import (
"errors"
"github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/sirupsen/logrus"
)
const (
certBasePath = "certs/"
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
RegistrationFile = certBasePath + "registration.json"
)
const (
ProviderLocal = "local"
ProviderCloudflare = "cloudflare"
ProviderClouddns = "clouddns"
ProviderDuckdns = "duckdns"
ProviderOVH = "ovh"
)
var providersGenMap = map[string]ProviderGenerator{
ProviderLocal: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderCloudflare: providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
}
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert")

View File

@@ -0,0 +1,20 @@
package autocert
type DummyConfig struct{}
type DummyProvider struct{}
func NewDummyDefaultConfig() *DummyConfig {
return &DummyConfig{}
}
func NewDummyDNSProviderConfig(*DummyConfig) (*DummyProvider, error) {
return &DummyProvider{}, nil
}
func (DummyProvider) Present(domain, token, keyAuth string) error {
return nil
}
func (DummyProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}

View File

@@ -0,0 +1,295 @@
package autocert
import (
"context"
"crypto/tls"
"crypto/x509"
"os"
"path"
"reflect"
"sort"
"time"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
U "github.com/yusing/go-proxy/internal/utils"
)
type Provider struct {
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
tlsCert *tls.Certificate
certExpiries CertExpiries
}
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type CertExpiries map[string]time.Time
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
return nil, ErrGetCertFailure
}
return p.tlsCert, nil
}
func (p *Provider) GetName() string {
return p.cfg.Provider
}
func (p *Provider) GetCertPath() string {
return p.cfg.CertPath
}
func (p *Provider) GetKeyPath() string {
return p.cfg.KeyPath
}
func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() (res E.NestedError) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
if p.cfg.Provider == ProviderLocal {
return nil
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
b.Add(E.FailWith("init autocert client", err))
return
}
}
if p.user.Registration == nil {
if err := p.registerACME(); err.HasError() {
b.Add(E.FailWith("register ACME", err))
return
}
}
client := p.client
req := certificate.ObtainRequest{
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.HasError() {
b.Add(err)
return
}
if err = p.saveCert(cert); err.HasError() {
b.Add(E.FailWith("save certificate", err))
return
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.HasError() {
b.Add(E.FailWith("parse obtained certificate", err))
return
}
expiries, err := getCertExpiries(&tlsCert)
if err.HasError() {
b.Add(E.FailWith("get certificate expiry", err))
return
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
return nil
}
func (p *Provider) LoadCert() E.NestedError {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.HasError() {
return err
}
expiries, err := getCertExpiries(&cert)
if err.HasError() {
return err
}
p.tlsCert = &cert
p.certExpiries = expiries
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before
}
// this line should never be reached
panic("no certificate available")
}
func (p *Provider) ScheduleRenewal(ctx context.Context) {
if p.GetName() == ProviderLocal {
return
}
logger.Debug("started renewal scheduler")
defer logger.Debug("renewal scheduler stopped")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
}
}
}
}
func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.FailWith("create lego client", err)
}
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.FailWith("create lego provider", err)
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.FailWith("set challenge provider", err)
}
p.client = legoClient
return nil
}
func (p *Provider) registerACME() E.NestedError {
if p.user.Registration != nil {
return nil
}
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
return err
}
p.user.Registration = reg
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
//* This should have been done in setup
//* but double check is always a good choice
_, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
return E.FailWith("create cert directory", err)
}
} else {
return E.FailWith("stat cert directory", err)
}
}
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.FailWith("write key file", err)
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.FailWith("write cert file", err)
}
return nil
}
func (p *Provider) certState() CertState {
if time.Now().After(p.ShouldRenewOn()) {
return CertStateExpired
}
certDomains := make([]string, len(p.certExpiries))
wantedDomains := make([]string, len(p.cfg.Domains))
i := 0
for domain := range p.certExpiries {
certDomains[i] = domain
i++
}
copy(wantedDomains, p.cfg.Domains)
sort.Strings(wantedDomains)
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}
return CertStateValid
}
func (p *Provider) renewIfNeeded() E.NestedError {
if p.cfg.Provider == ProviderLocal {
return nil
}
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
default:
return nil
}
if err := p.ObtainCert(); err.HasError() {
return E.FailWith("renew certificate", err)
}
return nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.HasError() {
return nil, E.FailWith("parse certificate", err)
}
if x509Cert.IsCA {
continue
}
r[x509Cert.Subject.CommonName] = x509Cert.NotAfter
for i := range x509Cert.DNSNames {
r[x509Cert.DNSNames[i]] = x509Cert.NotAfter
}
}
return r, nil
}
func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt M.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
cfg := defaultCfg()
err := U.Deserialize(opt, cfg)
if err.HasError() {
return nil, err
}
p, err := E.Check(newProvider(cfg))
if err.HasError() {
return nil, err
}
return p, nil
}
}

View File

@@ -0,0 +1,50 @@
package provider_test
import (
"testing"
"github.com/go-acme/lego/v4/providers/dns/ovh"
U "github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
"gopkg.in/yaml.v3"
)
// type Config struct {
// APIEndpoint string
// ApplicationKey string
// ApplicationSecret string
// ConsumerKey string
// OAuth2Config *OAuth2Config
// PropagationTimeout time.Duration
// PollingInterval time.Duration
// TTL int
// HTTPClient *http.Client
// }
func TestOVH(t *testing.T) {
cfg := &ovh.Config{}
testYaml := `
api_endpoint: https://eu.api.ovh.com
application_key: <application_key>
application_secret: <application_secret>
consumer_key: <consumer_key>
oauth2_config:
client_id: <client_id>
client_secret: <client_secret>
`
cfgExpected := &ovh.Config{
APIEndpoint: "https://eu.api.ovh.com",
ApplicationKey: "<application_key>",
ApplicationSecret: "<application_secret>",
ConsumerKey: "<consumer_key>",
OAuth2Config: &ovh.OAuth2Config{ClientID: "<client_id>", ClientSecret: "<client_secret>"},
}
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectTrue(t, U.Deserialize(opt, cfg).NoError())
ExpectDeepEqual(t, cfg, cfgExpected)
}

View File

@@ -0,0 +1,29 @@
package autocert
import (
"context"
"os"
E "github.com/yusing/go-proxy/internal/error"
)
func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
logger.Debug("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
}
go p.ScheduleRenewal(ctx)
for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry)
break
}
return nil
}

View File

@@ -0,0 +1,9 @@
package autocert
type CertState int
const (
CertStateValid CertState = iota
CertStateExpired
CertStateMismatch
)

22
internal/autocert/user.go Normal file
View File

@@ -0,0 +1,22 @@
package autocert
import (
"github.com/go-acme/lego/v4/registration"
"crypto"
)
type User struct {
Email string
Registration *registration.Resource
key crypto.PrivateKey
}
func (u *User) GetEmail() string {
return u.Email
}
func (u *User) GetRegistration() *registration.Resource {
return u.Registration
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}

53
internal/common/args.go Normal file
View File

@@ -0,0 +1,53 @@
package common
import (
"flag"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
type Args struct {
Command string
}
const (
CommandStart = ""
CommandSetup = "setup"
CommandValidate = "validate"
CommandListConfigs = "ls-config"
CommandListRoutes = "ls-routes"
CommandReload = "reload"
CommandDebugListEntries = "debug-ls-entries"
CommandDebugListProviders = "debug-ls-providers"
)
var ValidCommands = []string{
CommandStart,
CommandSetup,
CommandValidate,
CommandListConfigs,
CommandListRoutes,
CommandReload,
CommandDebugListEntries,
CommandDebugListProviders,
}
func GetArgs() Args {
var args Args
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArg(args.Command); err.HasError() {
logrus.Fatal(err)
}
return args
}
func validateArg(arg string) E.NestedError {
for _, v := range ValidCommands {
if arg == v {
return nil
}
}
return E.Invalid("argument", arg)
}

View File

@@ -0,0 +1,44 @@
package common
import (
"time"
)
const (
ConnectionTimeout = 5 * time.Second
DialTimeout = 3 * time.Second
KeepAlive = 60 * time.Second
)
// file, folder structure
const (
ConfigBasePath = "config"
ConfigFileName = "config.yml"
ConfigExampleFileName = "config.example.yml"
ConfigPath = ConfigBasePath + "/" + ConfigFileName
)
const (
SchemaBasePath = "schema"
ConfigSchemaPath = SchemaBasePath + "/config.schema.json"
FileProviderSchemaPath = SchemaBasePath + "/providers.schema.json"
)
const (
ComposeFileName = "compose.yml"
ComposeExampleFileName = "compose.example.yml"
)
const (
ErrorPagesBasePath = "error_pages"
)
const DockerHostFromEnv = "$DOCKER_HOST"
const (
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "30s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)

55
internal/common/env.go Normal file
View File

@@ -0,0 +1,55 @@
package common
import (
"fmt"
"net"
"os"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils"
)
var (
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION")
IsDebug = GetEnvBool("GOPROXY_DEBUG")
ProxyHTTPAddr,
ProxyHTTPHost,
ProxyHTTPPort,
ProxyHTTPURL = GetAddrEnv("GOPROXY_HTTP_ADDR", ":80", "http")
ProxyHTTPSAddr,
ProxyHTTPSHost,
ProxyHTTPSPort,
ProxyHTTPSURL = GetAddrEnv("GOPROXY_HTTPS_ADDR", ":443", "https")
APIHTTPAddr,
APIHTTPHost,
APIHTTPPort,
APIHTTPURL = GetAddrEnv("GOPROXY_API_ADDR", "127.0.0.1:8888", "http")
)
func GetEnvBool(key string) bool {
return U.ParseBool(os.Getenv(key))
}
func GetEnv(key, defaultValue string) string {
value, ok := os.LookupEnv(key)
if !ok {
value = defaultValue
}
return value
}
func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL string) {
addr = GetEnv(key, defaultValue)
host, port, err := net.SplitHostPort(addr)
if err != nil {
logrus.Fatalf("Invalid address: %s", addr)
}
if host == "" {
host = "localhost"
}
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
return
}

25
internal/common/http.go Normal file
View File

@@ -0,0 +1,25 @@
package common
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
defaultDialer = net.Dialer{
Timeout: 60 * time.Second,
KeepAlive: 60 * time.Second,
}
DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
MaxIdleConnsPerHost: 1000,
}
DefaultTransportNoTLS = func() *http.Transport {
var clone = DefaultTransport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
)

75
internal/common/ports.go Normal file
View File

@@ -0,0 +1,75 @@
package common
var (
WellKnownHTTPPorts = map[string]bool{
"80": true,
"8000": true,
"8008": true,
"8080": true,
"3000": true,
}
ServiceNamePortMapTCP = map[string]int{
"mssql": 1433,
"mysql": 3306,
"mariadb": 3306,
"postgres": 5432,
"rabbitmq": 5672,
"redis": 6379,
"memcached": 11211,
"mongo": 27017,
"minecraft-server": 25565,
"ssh": 22,
"ftp": 21,
"smtp": 25,
"dns": 53,
"pop3": 110,
"imap": 143,
}
ImageNamePortMap = func() (m map[string]int) {
m = make(map[string]int, len(ServiceNamePortMapTCP)+len(imageNamePortMap))
for k, v := range ServiceNamePortMapTCP {
m[k] = v
}
for k, v := range imageNamePortMap {
m[k] = v
}
return
}()
imageNamePortMap = map[string]int{
"adguardhome": 3000,
"bazarr": 6767,
"calibre-web": 8083,
"changedetection.io": 3000,
"dockge": 5001,
"gitea": 3000,
"gogs": 3000,
"grafana": 3000,
"home-assistant": 8123,
"homebridge": 8581,
"httpd": 80,
"immich": 3001,
"jellyfin": 8096,
"lidarr": 8686,
"microbin": 8080,
"nginx": 80,
"nginx-proxy-manager": 81,
"open-webui": 8080,
"plex": 32400,
"portainer-be": 9443,
"portainer-ce": 9443,
"prometheus": 9090,
"prowlarr": 9696,
"radarr": 7878,
"radarr-sma": 7878,
"rsshub": 1200,
"rss-bridge": 80,
"sonarr": 8989,
"sonarr-sma": 8989,
"uptime-kuma": 3001,
"whisparr": 6969,
}
)

235
internal/config/config.go Normal file
View File

@@ -0,0 +1,235 @@
package config
import (
"context"
"os"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
"gopkg.in/yaml.v3"
)
type Config struct {
value *M.Config
proxyProviders F.Map[string, *PR.Provider]
autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
reloadReq chan struct{}
}
var instance *Config
func GetConfig() *Config {
return instance
}
func Load() E.NestedError {
if instance != nil {
return nil
}
instance = &Config{
value: M.DefaultConfig(),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
return instance.load()
}
func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
}
func MatchDomains() []string {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return instance.value.MatchDomains
}
func (cfg *Config) Value() M.Config {
if cfg == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return *cfg.value
}
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return cfg.autocertProvider
}
func (cfg *Config) Dispose() {
if cfg.watcherCancel != nil {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
}
cfg.stopProviders()
}
func (cfg *Config) Reload() (err E.NestedError) {
cfg.stopProviders()
err = cfg.load()
cfg.StartProxyProviders()
return
}
func (cfg *Config) StartProxyProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
}
func (cfg *Config) WatchChanges() {
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
go func() {
for {
select {
case <-cfg.watcherCtx.Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err.HasError() {
cfg.l.Error(err)
}
}
}
}()
go func() {
eventCh, errCh := cfg.watcher.Events(cfg.watcherCtx)
for {
select {
case <-cfg.watcherCtx.Done():
return
case event := <-eventCh:
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
cfg.l.Error("config file deleted or renamed, ignoring...")
continue
} else {
cfg.reloadReq <- struct{}{}
}
case err := <-errCh:
cfg.l.Error(err)
continue
}
}
}()
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
do(a, r, p)
})
})
}
func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
b.Add(E.FailWith("read config", err))
logrus.Fatal(b.Build())
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
b.Add(E.FailWith("schema validation", err))
logrus.Fatal(b.Build())
}
}
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
logrus.Fatal(b.Build())
}
// errors are non fatal below
b.Add(cfg.initAutoCert(&model.AutoCert))
b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model
R.SetFindMuxDomains(model.MatchDomains)
return
}
func (cfg *Config) initAutoCert(autocertCfg *M.AutoCertConfig) (err E.NestedError) {
if cfg.autocertProvider != nil {
return
}
cfg.l.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
for _, filename := range providers.Files {
p, err := PR.NewFileProvider(filename)
if err != nil {
b.Add(err.Subject(filename))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(filename))
}
for name, dockerHost := range providers.Docker {
p, err := PR.NewDockerProvider(name, dockerHost)
if err != nil {
b.Add(err.Subjectf("%s (%s)", name, dockerHost))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(dockerHost))
}
return
}
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
errors.Add(err.Subject(p))
}
})
if err := errors.Build(); err.HasError() {
cfg.l.Error(err)
}
}
func (cfg *Config) stopProviders() {
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
}

82
internal/config/query.go Normal file
View File

@@ -0,0 +1,82 @@
package config
import (
M "github.com/yusing/go-proxy/internal/models"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
func (cfg *Config) DumpEntries() map[string]*M.RawEntry {
entries := make(map[string]*M.RawEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
})
return entries
}
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
entries := make(map[string]*PR.Provider)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
entries[name] = p
})
return entries
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
routes[alias] = obj
})
return routes
}
func (cfg *Config) Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]any)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
s, ok := providerStats[p.GetName()]
if !ok {
s = make(map[string]int)
}
stats := s.(map[string]int)
switch r.Type() {
case R.RouteTypeStream:
stats["num_streams"]++
nTotalStreams++
case R.RouteTypeReverseProxy:
stats["num_reverse_proxies"]++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
return map[string]any{
"num_total_streams": nTotalStreams,
"num_total_reverse_proxies": nTotalRPs,
"providers": providerStats,
}
}
func (cfg *Config) FindRoute(alias string) R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}
return nil, false
},
)
}

153
internal/docker/client.go Normal file
View File

@@ -0,0 +1,153 @@
package docker
import (
"net/http"
"sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type Client struct {
key string
refCount *atomic.Int32
*client.Client
l logrus.FieldLogger
}
func ParseDockerHostname(host string) (string, E.NestedError) {
switch host {
case common.DockerHostFromEnv, "":
return "localhost", nil
}
url, err := E.Check(client.ParseHostURL(host))
if err != nil {
return "", E.Invalid("host", host).With(err)
}
return url.Hostname(), nil
}
func (c Client) DaemonHostname() string {
// DaemonHost should always return a valid host
hostname, _ := ParseDockerHostname(c.DaemonHost())
return hostname
}
func (c Client) Connected() bool {
return c.Client != nil
}
// if the client is still referenced, this is no-op
func (c *Client) Close() error {
if c.refCount.Add(-1) > 0 {
return nil
}
clientMap.Delete(c.key)
client := c.Client
c.Client = nil
c.l.Debugf("client closed")
if client != nil {
return client.Close()
}
return nil
}
// ConnectClient creates a new Docker client connection to the specified host.
//
// Returns existing client if available.
//
// Parameters:
// - host: the host to connect to (either a URL or common.DockerHostFromEnv).
//
// Returns:
// - Client: the Docker client connection.
// - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.NestedError) {
clientMapMu.Lock()
defer clientMapMu.Unlock()
// check if client exists
if client, ok := clientMap.Load(host); ok {
client.refCount.Add(1)
return client, nil
}
// create client
var opt []client.Opt
switch host {
case common.DockerHostFromEnv:
opt = clientOptEnvHost
default:
helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() {
return Client{}, E.UnexpectedError(err.Error())
}
if helper != nil {
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: helper.Dialer,
},
}
opt = []client.Opt{
client.WithHTTPClient(httpClient),
client.WithHost(helper.Host),
client.WithAPIVersionNegotiation(),
client.WithDialContext(helper.Dialer),
}
} else {
opt = []client.Opt{
client.WithHost(host),
client.WithAPIVersionNegotiation(),
}
}
}
client, err := E.Check(client.NewClientWithOpts(opt...))
if err.HasError() {
return Client{}, err
}
c := Client{
Client: client,
key: host,
refCount: &atomic.Int32{},
l: logger.WithField("docker_client", client.DaemonHost()),
}
c.refCount.Add(1)
c.l.Debugf("client connected")
clientMap.Store(host, c)
return c, nil
}
func CloseAllClients() {
clientMap.RangeAll(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()
logger.Debug("closed all clients")
}
var (
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
logger = logrus.WithField("module", "docker")
)

View File

@@ -0,0 +1,54 @@
package docker
import (
"context"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
E "github.com/yusing/go-proxy/internal/error"
)
type ClientInfo struct {
Client Client
Containers []types.Container
}
var listOptions = container.ListOptions{
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// ),
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
dockerClient, err := ConnectClient(clientHost)
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
var containers []types.Container
if getContainer {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
}
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
}
func IsErrConnectionFailed(err error) bool {
return client.IsErrConnectionFailed(err)
}

View File

@@ -0,0 +1,105 @@
package docker
import (
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/internal/utils"
)
type Container struct {
*types.Container
*ProxyProperties
}
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ImageName: res.getImageName(),
PublicPortMapping: res.getPublicPortMapping(),
PrivatePortMapping: res.getPrivatePortMapping(),
NetworkMode: c.HostConfig.NetworkMode,
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
StopSignal: res.getDeleteLabel(LabelStopSignal),
Running: c.Status == "running" || c.State == "running",
}
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
for _, v := range bindings {
pubPort, _ := strconv.ParseUint(v.HostPort, 10, 16)
privPort, _ := strconv.ParseUint(k.Port(), 10, 16)
ports = append(ports, types.Port{
IP: v.HostIP,
PublicPort: uint16(pubPort),
PrivatePort: uint16(privPort),
})
}
}
return FromDocker(&types.Container{
ID: json.ID,
Names: []string{json.Name},
Image: json.Image,
Ports: ports,
Labels: json.Config.Labels,
State: json.State.Status,
Status: json.State.Status,
}, dockerHost)
}
func (c Container) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
}
return ""
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
}
}
func (c Container) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c Container) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[0], "/")
return slashSep[len(slashSep)-1]
}
func (c Container) getPublicPortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
if v.PublicPort == 0 {
continue
}
res[fmt.Sprint(v.PublicPort)] = v
}
return res
}
func (c Container) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
res[fmt.Sprint(v.PrivatePort)] = v
}
return res
}

View File

@@ -0,0 +1,32 @@
package docker
type (
HomePageConfig struct{ m map[string]HomePageCategory }
HomePageCategory []HomePageItem
HomePageItem struct {
Name string
Icon string
Category string
Description string
WidgetConfig map[string]any
}
)
func NewHomePageConfig() *HomePageConfig {
return &HomePageConfig{m: make(map[string]HomePageCategory)}
}
func NewHomePageItem() *HomePageItem {
return &HomePageItem{}
}
func (c *HomePageConfig) Clear() {
c.m = make(map[string]HomePageCategory)
}
func (c *HomePageConfig) Add(item HomePageItem) {
c.m[item.Category] = HomePageCategory{item}
}
const NSHomePage = "homepage"

View File

@@ -0,0 +1,87 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>{{.Title}}</title>
<style>
/* Global Styles */
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: Inter, Arial, sans-serif;
font-size: 16px;
line-height: 1.5;
color: #fff;
background-color: #212121;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
/* Spinner Styles */
.spinner {
width: 120px;
height: 120px;
border: 16px solid #333;
border-radius: 50%;
border-top: 16px solid #66d9ef;
animation: spin 2s linear infinite;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
/* Error Styles */
.error {
display: inline-block;
text-align: center;
justify-content: center;
}
.error::before {
content: "\26A0"; /* Unicode for warning symbol */
font-size: 40px;
color: #ff9900;
}
/* Message Styles */
.message {
font-size: 24px;
font-weight: bold;
padding-left: 32px;
text-align: center;
}
</style>
</head>
<body>
<script>
window.onload = async function () {
let result = await fetch(window.location.href, {
headers: {
{{ range $key, $value := .RequestHeaders }}
'{{ $key }}' : {{ $value }}
{{ end }}
},
}).then((resp) => resp.text())
.catch((err) => {
document.getElementById("message").innerText = err;
});
if (result) {
document.documentElement.innerHTML = result
}
};
</script>
<div class="{{.SpinnerClass}}"></div>
<div class="message">{{.Message}}</div>
</body>
</html>

View File

@@ -0,0 +1,87 @@
package idlewatcher
import (
"bytes"
_ "embed"
"fmt"
"io"
"net/http"
"strings"
"text/template"
)
type templateData struct {
Title string
Message string
RequestHeaders http.Header
SpinnerClass string
}
//go:embed html/loading_page.html
var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
const (
htmlContentType = "text/html; charset=utf-8"
errPrefix = "\u1000"
headerGoProxyTargetURL = "X-GoProxy-Target"
headerContentType = "Content-Type"
spinnerClassSpinner = "spinner"
spinnerClassErrorSign = "error"
)
func (w *watcher) makeSuccResp(redirectURL string, resp *http.Response) (*http.Response, error) {
h := make(http.Header)
h.Set("Location", redirectURL)
h.Set("Content-Length", "0")
h.Set(headerContentType, htmlContentType)
return &http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: h,
Body: http.NoBody,
TLS: resp.TLS,
}, nil
}
func (w *watcher) makeErrResp(errFmt string, args ...any) (*http.Response, error) {
return w.makeResp(errPrefix+errFmt, args...)
}
func (w *watcher) makeResp(format string, args ...any) (*http.Response, error) {
msg := fmt.Sprintf(format, args...)
data := new(templateData)
data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
data.Message = strings.ReplaceAll(data.Message, " ", "&ensp;")
data.RequestHeaders = make(http.Header)
data.RequestHeaders.Add(headerGoProxyTargetURL, "window.location.href")
if strings.HasPrefix(data.Message, errPrefix) {
data.Message = strings.TrimLeft(data.Message, errPrefix)
data.SpinnerClass = spinnerClassErrorSign
} else {
data.SpinnerClass = spinnerClassSpinner
}
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen
panic(err)
}
return &http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
headerContentType: {htmlContentType},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(buf),
ContentLength: int64(buf.Len()),
}, nil
}

View File

@@ -0,0 +1,83 @@
package idlewatcher
import (
"context"
"net/http"
)
type (
roundTripper struct {
patched roundTripFunc
}
roundTripFunc func(*http.Request) (*http.Response, error)
)
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.patched(req)
}
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
// wake the container
select {
case w.wakeCh <- struct{}{}:
default:
}
// target site is ready, passthrough
if w.ready.Load() {
return origRoundTrip(req)
}
// initial request
targetUrl := req.Header.Get(headerGoProxyTargetURL)
if targetUrl == "" {
return w.makeResp(
"%s is starting... Please wait",
w.ContainerName,
)
}
w.l.Debug("serving event")
// stream request
rtDone := make(chan *http.Response, 1)
ctx, cancel := context.WithTimeout(req.Context(), w.WakeTimeout)
defer cancel()
// loop original round trip until success in a goroutine
go func() {
for {
select {
case <-ctx.Done():
return
case <-w.ctx.Done():
return
default:
resp, err := origRoundTrip(req)
if err == nil {
w.ready.Store(true)
rtDone <- resp
return
}
}
}
}()
for {
select {
case resp := <-rtDone:
return w.makeSuccResp(targetUrl, resp)
case err := <-w.wakeDone:
if err != nil {
return w.makeErrResp("error waking up %s\n%s", w.ContainerName, err.Error())
}
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
return w.makeErrResp("Timed out waiting for %s to fully wake", w.ContainerName)
}
return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error())
case <-w.ctx.Done():
return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error())
}
}
}

View File

@@ -0,0 +1,276 @@
package idlewatcher
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
W "github.com/yusing/go-proxy/internal/watcher"
)
type (
watcher struct {
*P.ReverseProxyEntry
client D.Client
ready atomic.Bool // whether the site is ready to accept connection
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
wakeCh chan struct{}
wakeDone chan E.NestedError
ctx context.Context
cancel context.CancelFunc
refCount *sync.WaitGroup
l logrus.FieldLogger
}
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() E.NestedError
)
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[entry.ContainerName]; ok {
w.refCount.Add(1)
w.ReverseProxyEntry = entry
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
if err.HasError() {
return nil, failure.With(err)
}
w := &watcher{
ReverseProxyEntry: entry,
client: client,
refCount: &sync.WaitGroup{},
wakeCh: make(chan struct{}),
wakeDone: make(chan E.NestedError),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.stopByMethod = w.getStopCallback()
watcherMap[w.ContainerName] = w
go func() {
newWatcherCh <- w
}()
return w, nil
}
func Unregister(containerName string) {
if w, ok := watcherMap[containerName]; ok {
w.refCount.Add(-1)
}
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watchUntilCancel()
w.refCount.Wait() // wait for 0 ref count
w.client.Close()
delete(watcherMap, w.ContainerName)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
}
func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
return roundTripper{patched: func(r *http.Request) (*http.Response, error) {
return w.roundTrip(rtp.RoundTrip, r)
}}
}
func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
}
func (w *watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerName)
}
func (w *watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerName, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerName)
}
func (w *watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerName, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerName)
if err != nil {
return "", E.FailWith("inspect container", err)
}
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch status {
case "exited", "dead":
return E.From(w.containerStart())
case "paused":
return E.From(w.containerUnpause())
case "running":
return nil
default:
return E.Unexpected("container state", status)
}
}
func (w *watcher) getStopCallback() StopCallback {
var cb func() error
switch w.StopMethod {
case PT.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
if status != "running" {
return nil
}
return E.From(cb())
}
}
func (w *watcher) watchUntilCancel() {
defer close(w.wakeCh)
w.ctx, w.cancel = context.WithCancel(context.Background())
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainerName(w.ContainerName),
W.DockerFilterStart,
W.DockerFilterStop,
W.DockerFilterDie,
W.DockerFilterKill,
W.DockerFilterPause,
W.DockerFilterUnpause,
),
})
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
for {
select {
case <-mainLoopCtx.Done():
w.cancel()
case <-w.ctx.Done():
w.l.Debug("stopped")
return
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))
}
case e := <-dockerEventCh:
switch {
// create / start / unpause
case e.Action.IsContainerWake():
ticker.Reset(w.IdleTimeout)
w.l.Info(e)
default: // stop / pause / kill
ticker.Stop()
w.ready.Store(false)
w.l.Info(e)
}
case <-ticker.C:
w.l.Debug("idle timeout")
ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
}
case <-w.wakeCh:
w.l.Debug("wake signal received")
ticker.Reset(w.IdleTimeout)
err := w.wakeIfStopped()
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("wake", err))
}
select {
case w.wakeDone <- err: // this is passed to roundtrip
default:
}
}
}
}

View File

@@ -0,0 +1,19 @@
package docker
import (
"context"
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return Container{}, E.From(err)
}
return FromJson(json, c.key), nil
}

151
internal/docker/label.go Normal file
View File

@@ -0,0 +1,151 @@
package docker
import (
"reflect"
"strings"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
/*
Formats:
- namespace.attribute
- namespace.target.attribute
- namespace.target.attribute.namespace2.attribute
*/
type (
Label struct {
Namespace string
Target string
Attribute string
Value any
}
NestedLabelMap map[string]U.SerializedObject
ValueParser func(string) (any, E.NestedError)
ValueParserMap map[string]ValueParser
)
func (l *Label) String() string {
if l.Attribute == "" {
return l.Namespace + "." + l.Target
}
return l.Namespace + "." + l.Target + "." + l.Attribute
}
// Apply applies the value of a Label to the corresponding field in the given object.
//
// Parameters:
// - obj: a pointer to the object to which the Label value will be applied.
// - l: a pointer to the Label containing the attribute and value to be applied.
//
// Returns:
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
if obj == nil {
return E.Invalid("nil object", l)
}
switch nestedLabel := l.Value.(type) {
case *Label:
var field reflect.Value
objType := reflect.TypeFor[T]()
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
field = reflect.ValueOf(obj).Elem().Field(i)
break
}
}
if !field.IsValid() {
return E.NotExist("field", l.Attribute)
}
dst, ok := field.Interface().(NestedLabelMap)
if !ok {
return E.Invalid("type", field.Type())
}
if dst == nil {
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
dst = field.Interface().(NestedLabelMap)
}
if dst[nestedLabel.Namespace] == nil {
dst[nestedLabel.Namespace] = make(U.SerializedObject)
}
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
return nil
default:
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
}
}
func ParseLabel(label string, value string) (*Label, E.NestedError) {
parts := strings.Split(label, ".")
if len(parts) < 2 {
return &Label{
Namespace: label,
Value: value,
}, nil
}
l := &Label{
Namespace: parts[0],
Target: parts[1],
Value: value,
}
switch len(parts) {
case 2:
l.Attribute = l.Target
case 3:
l.Attribute = parts[2]
default:
l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
if err.HasError() {
return nil, err
}
l.Value = nestedLabel
}
// find if namespace has value parser
pm, ok := valueParserMap.Load(U.ToLowerNoSnake(l.Namespace))
if !ok {
return l, nil
}
// find if attribute has value parser
p, ok := pm[U.ToLowerNoSnake(l.Attribute)]
if !ok {
return l, nil
}
// try to parse value
v, err := p(value)
if err.HasError() {
return nil, err.Subject(label)
}
l.Value = v
return l, nil
}
func RegisterNamespace(namespace string, pm ValueParserMap) {
pmCleaned := make(ValueParserMap, len(pm))
for k, v := range pm {
pmCleaned[U.ToLowerNoSnake(k)] = v
}
valueParserMap.Store(U.ToLowerNoSnake(namespace), pmCleaned)
}
func GetRegisteredNamespaces() map[string][]string {
r := make(map[string][]string)
valueParserMap.RangeAll(func(ns string, vpm ValueParserMap) {
r[ns] = make([]string, 0, len(vpm))
for attr := range vpm {
r[ns] = append(r[ns], attr)
}
})
return r
}
// namespace:target.attribute -> func(string) (any, error)
var valueParserMap = F.NewMapOf[string, ValueParserMap]()

View File

@@ -0,0 +1,78 @@
package docker
import (
"strings"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
const (
NSProxy = "proxy"
ProxyAttributePathPatterns = "path_patterns"
ProxyAttributeNoTLSVerify = "no_tls_verify"
ProxyAttributeMiddlewares = "middlewares"
)
var _ = func() int {
RegisterNamespace(NSProxy, ValueParserMap{
ProxyAttributePathPatterns: YamlStringListParser,
ProxyAttributeNoTLSVerify: BoolParser,
})
return 0
}()
func YamlStringListParser(value string) (any, E.NestedError) {
/*
- foo
- bar
- baz
*/
value = strings.TrimSpace(value)
if value == "" {
return []string{}, nil
}
var data []string
err := E.From(yaml.Unmarshal([]byte(value), &data))
return data, err
}
func YamlLikeMappingParser(allowDuplicate bool) func(string) (any, E.NestedError) {
return func(value string) (any, E.NestedError) {
/*
foo: bar
boo: baz
*/
value = strings.TrimSpace(value)
lines := strings.Split(value, "\n")
h := make(map[string]string)
for _, line := range lines {
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
return nil, E.Invalid("syntax", line).With("too many colons")
}
key := strings.TrimSpace(parts[0])
val := strings.TrimSpace(parts[1])
if existing, ok := h[key]; ok {
if !allowDuplicate {
return nil, E.Duplicated("key", key)
}
h[key] = existing + ", " + val
} else {
h[key] = val
}
}
return h, nil
}
}
func BoolParser(value string) (any, E.NestedError) {
switch strings.ToLower(value) {
case "true", "yes", "1":
return true, nil
case "false", "no", "0":
return false, nil
default:
return nil, E.Invalid("boolean value", value)
}
}

View File

@@ -0,0 +1,106 @@
package docker
import (
"fmt"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func makeLabel(namespace string, alias string, field string) string {
return fmt.Sprintf("%s.%s.%s", namespace, alias, field)
}
func TestParseLabel(t *testing.T) {
alias := "foo"
field := "ip"
v := "bar"
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Namespace, NSHomePage)
ExpectEqual(t, pl.Target, alias)
ExpectEqual(t, pl.Attribute, field)
ExpectEqual(t, pl.Value.(string), v)
}
func TestStringProxyLabel(t *testing.T) {
v := "bar"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(string), v)
}
func TestBoolProxyLabelValid(t *testing.T) {
tests := map[string]bool{
"true": true,
"TRUE": true,
"yes": true,
"1": true,
"false": false,
"FALSE": false,
"no": false,
"0": false,
}
for k, v := range tests {
pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), k)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(bool), v)
}
}
func TestBoolProxyLabelInvalid(t *testing.T) {
_, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), "invalid")
if !err.Is(E.ErrInvalid) {
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
}
}
// func TestSetHeaderProxyLabelValid(t *testing.T) {
// v := `
// X-Custom-Header1: foo, bar
// X-Custom-Header1: baz
// X-Custom-Header2: boo`
// v = strings.TrimPrefix(v, "\n")
// h := map[string]string{
// "X-Custom-Header1": "foo, bar, baz",
// "X-Custom-Header2": "boo",
// }
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
// ExpectNoError(t, err.Error())
// hGot := ExpectType[map[string]string](t, pl.Value)
// ExpectFalse(t, hGot == nil)
// ExpectDeepEqual(t, h, hGot)
// }
// func TestSetHeaderProxyLabelInvalid(t *testing.T) {
// tests := []string{
// "X-Custom-Header1 = bar",
// "X-Custom-Header1",
// "- X-Custom-Header1",
// }
// for _, v := range tests {
// _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
// if !err.Is(E.ErrInvalid) {
// t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
// }
// }
// }
// func TestHideHeadersProxyLabel(t *testing.T) {
// v := `
// - X-Custom-Header1
// - X-Custom-Header2
// - X-Custom-Header3
// `
// v = strings.TrimPrefix(v, "\n")
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeHideHeaders), v)
// ExpectNoError(t, err.Error())
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
// ExpectFalse(t, sGot == nil)
// ExpectDeepEqual(t, sGot, sWant)
// }

View File

@@ -0,0 +1,85 @@
package docker
import (
"fmt"
"testing"
U "github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNestedLabel(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[*Label](t, pl.Value)
ExpectFalse(t, sGot == nil)
ExpectEqual(t, sGot.Namespace, mName)
ExpectEqual(t, sGot.Attribute, mAttr)
}
func TestApplyNestedLabel(t *testing.T) {
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
ExpectEqual(t, got, v)
}
func TestApplyNestedLabelExisting(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
checkAttr := "prop2"
checkV := "value2"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
entry.Middlewares[mName][checkAttr] = checkV
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
ExpectEqual(t, got, v)
// check if prop2 is affected
ExpectFalse(t, middleware1[checkAttr] == nil)
got = ExpectType[string](t, middleware1[checkAttr])
ExpectEqual(t, got, checkV)
}
func TestApplyNestedLabelNoAttr(t *testing.T) {
mName := "middleware1"
v := "value1"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", ProxyAttributeMiddlewares, mName)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
_, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
}

13
internal/docker/labels.go Normal file
View File

@@ -0,0 +1,13 @@
package docker
const (
WildcardAlias = "*"
LabelAliases = NSProxy + ".aliases"
LabelExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method"
LabelStopTimeout = NSProxy + ".stop_timeout"
LabelStopSignal = NSProxy + ".stop_signal"
)

View File

@@ -0,0 +1,22 @@
package docker
import "github.com/docker/docker/api/types"
type PortMapping = map[string]types.Port
type ProxyProperties struct {
DockerHost string `yaml:"-" json:"docker_host"`
ContainerName string `yaml:"-" json:"container_name"`
ImageName string `yaml:"-" json:"image_name"`
PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port
NetworkMode string `yaml:"-" json:"network_mode"`
Aliases []string `yaml:"-" json:"aliases"`
IsExcluded bool `yaml:"-" json:"is_excluded"`
IdleTimeout string `yaml:"-" json:"idle_timeout"`
WakeTimeout string `yaml:"-" json:"wake_timeout"`
StopMethod string `yaml:"-" json:"stop_method"`
StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only
Running bool `yaml:"-" json:"running"`
}

70
internal/error/builder.go Normal file
View File

@@ -0,0 +1,70 @@
package error
import (
"fmt"
"sync"
)
type Builder struct {
*builder
}
type builder struct {
message string
errors []NestedError
sync.Mutex
}
func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
}
// adding nil / nil is no-op,
// you may safely pass expressions returning error to it
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
// TODO: if err severity is higher than b.severity, update b.severity
b.errors = append(b.errors, err)
b.Unlock()
}
return b
}
func (b Builder) AddE(err error) Builder {
return b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
// Otherwise, it returns a NestedError with the message and the errors collected.
//
// Returns:
// - NestedError: the built NestedError.
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return nil
} else if len(b.errors) == 1 {
return b.errors[0]
}
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *NestedError) {
if ptr == nil {
return
} else if *ptr == nil {
*ptr = b.Build()
} else {
(*ptr).With(b.Build())
}
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
}

View File

@@ -0,0 +1,53 @@
package error
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
for range 3 {
eb.Add(nil)
}
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().String()
expected1 :=
(`error occurred:
- Action 1 failed:
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner: 3`)
expected2 :=
(`error occurred:
- Action 1 failed:
- invalid Inner: "1"
- invalid Inner: "2"
- Action 2 failed:
- invalid Inner: "3"`)
if got != expected1 && got != expected2 {
t.Errorf("expected \n%s, got \n%s", expected1, got)
}
}

295
internal/error/error.go Normal file
View File

@@ -0,0 +1,295 @@
package error
import (
"encoding/json"
"errors"
"fmt"
"strings"
)
type (
NestedError = *nestedError
nestedError struct {
subject string
err error
extras []nestedError
}
jsonNestedError struct {
Subject string
Err string
Extras []jsonNestedError
}
)
func From(err error) NestedError {
if IsNil(err) {
return nil
}
return &nestedError{err: err}
}
func FromJSON(data []byte) (NestedError, bool) {
var j jsonNestedError
if err := json.Unmarshal(data, &j); err != nil {
return nil, false
}
if j.Err == "" {
return nil, false
}
extras := make([]nestedError, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
return nil, false
}
extras[i] = *extra
}
return &nestedError{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
}, true
}
// Check is a helper function that
// convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, NestedError) {
return obj, From(err)
}
func Join(message string, err ...NestedError) NestedError {
extras := make([]nestedError, len(err))
nErr := 0
for i, e := range err {
if e == nil {
continue
}
extras[i] = *e
nErr += 1
}
if nErr == 0 {
return nil
}
return &nestedError{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne NestedError) Is(err error) bool {
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne NestedError) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne NestedError) With(s any) NestedError {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case NestedError:
return ne.withError(ss)
case error:
return ne.withError(From(ss))
case string:
msg = ss
case fmt.Stringer:
return ne.appendMsg(ss.String())
default:
return ne.appendMsg(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any) NestedError {
if ne == nil {
return ne
}
var subject string
switch ss := s.(type) {
case string:
subject = ss
case fmt.Stringer:
subject = ss.String()
default:
subject = fmt.Sprint(s)
}
if ne.subject == "" {
ne.subject = subject
} else {
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
return ne
}
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q")
}
if strings.Contains(format, "%w") {
panic("Subjectf format should not contain %w")
}
ne.subject = fmt.Sprintf(format, args...)
return ne
}
func (ne NestedError) JSONObject() jsonNestedError {
extras := make([]jsonNestedError, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return jsonNestedError{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
}
}
func (ne NestedError) JSON() []byte {
b, _ := json.MarshalIndent(ne.JSONObject(), "", " ")
return b
}
func (ne NestedError) NoError() bool {
return ne == nil
}
func (ne NestedError) HasError() bool {
return ne != nil
}
func errorf(format string, args ...any) NestedError {
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj jsonNestedError) (NestedError, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
}
return FromJSON(data)
}
func (ne NestedError) withError(err NestedError) NestedError {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne NestedError) appendMsg(msg string) NestedError {
if ne == nil {
return nil
}
ne.err = fmt.Errorf("%w %s", ne.err, msg)
return ne
}
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return
}
sb.WriteString(ne.err.Error())
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
for _, extra := range ne.extras {
sb.WriteRune('\n')
extra.writeToSB(sb, level+1, "- ")
}
}
}
func (ne NestedError) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return errors.New(sb.String())
}
res = fmt.Errorf("%s%w", sb.String(), ne.err)
sb.Reset()
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
res = fmt.Errorf("%w%s", res, sb.String())
for _, extra := range ne.extras {
res = errors.Join(res, extra.buildError(level+1, "- "))
}
} else {
res = fmt.Errorf("%w%s", res, sb.String())
}
return res
}

View File

@@ -0,0 +1,109 @@
package error_test
import (
"errors"
"testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestErrorIs(t *testing.T) {
ExpectTrue(t, Failure("foo").Is(ErrFailure))
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
}
func TestErrorNestedIs(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrDuplicated))
err.With(Duplicated("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrDuplicated))
ExpectFalse(t, err.Is(ErrInvalid))
}
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
}
func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
}
func TestErrorNested(t *testing.T) {
inner := Failure("inner").
With("1").
With("1")
inner2 := Failure("inner2").
Subject("action 2").
With("2").
With("2")
inner3 := Failure("inner3").
Subject("action 3").
With("3").
With("3")
ne := Failure("foo").
With("bar").
With("baz").
With(inner).
With(inner.With(inner2.With(inner3)))
want :=
`foo failed:
- bar
- baz
- inner failed:
- 1
- 1
- inner failed:
- 1
- 1
- inner2 failed for "action 2":
- 2
- 2
- inner3 failed for "action 3":
- 3
- 3`
ExpectEqual(t, ne.String(), want)
ExpectEqual(t, ne.Error().Error(), want)
}

62
internal/error/errors.go Normal file
View File

@@ -0,0 +1,62 @@
package error
import (
stderrors "errors"
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
)
const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) NestedError {
return errorf("%s %w", what, ErrFailure)
}
func FailedWhy(what string, why string) NestedError {
return Failure(what).With(why)
}
func FailWith(what string, err any) NestedError {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func Unexpected(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func UnexpectedError(err error) NestedError {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func Missing(subject any) NestedError {
return errorf("%w %v", ErrMissing, subject)
}
func Duplicated(subject, what any) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject string, value any) NestedError {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}

View File

@@ -0,0 +1,32 @@
package http
import (
"mime"
"net/http"
)
type ContentType string
func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type")
if ct == "" {
return ""
}
ct, _, err := mime.ParseMediaType(ct)
if err != nil {
return ""
}
return ContentType(ct)
}
func (ct ContentType) IsHTML() bool {
return ct == "text/html" || ct == "application/xhtml+xml"
}
func (ct ContentType) IsJSON() bool {
return ct == "application/json"
}
func (ct ContentType) IsPlainText() bool {
return ct == "text/plain"
}

View File

@@ -0,0 +1,42 @@
package http
import (
"net/http"
"slices"
)
func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h)
if reqUpType != "" {
h.Set("Connection", "Upgrade")
h.Set("Upgrade", reqUpType)
} else {
h.Del("Connection")
}
}
func CopyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func FilterHeaders(h http.Header, allowed []string) {
if allowed == nil {
return
}
for i := range allowed {
allowed[i] = http.CanonicalHeaderKey(allowed[i])
}
for key := range h {
if !slices.Contains(allowed, key) {
h.Del(key)
}
}
}

View File

@@ -0,0 +1,96 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs
package http
import (
"bufio"
"fmt"
"net"
"net/http"
)
type ModifyResponseFunc func(*http.Response) error
type ModifyResponseWriter struct {
w http.ResponseWriter
r *http.Request
headerSent bool
code int
modifier ModifyResponseFunc
modified bool
modifierErr error
}
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
return &ModifyResponseWriter{
w: w,
r: r,
modifier: f,
code: http.StatusOK,
}
}
func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent {
return
}
if code >= http.StatusContinue && code < http.StatusOK {
w.w.WriteHeader(code)
}
defer func() {
w.headerSent = true
w.code = code
}()
if w.modifier == nil || w.modified {
w.w.WriteHeader(code)
return
}
resp := http.Response{
Header: w.w.Header(),
Request: w.r,
}
if err := w.modifier(&resp); err != nil {
w.modifierErr = err
logger.Errorf("error modifying response: %s", err)
w.w.WriteHeader(http.StatusInternalServerError)
return
}
w.modified = true
w.w.WriteHeader(code)
}
func (w *ModifyResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
w.WriteHeader(w.code)
if w.modifierErr != nil {
return 0, w.modifierErr
}
return w.w.Write(b)
}
// Hijack hijacks the connection.
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.w.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
}
// Flush sends any buffered data to the client.
func (w *ModifyResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,556 @@
// Copyright 2011 The Go Authors.
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
// https://cs.opensource.google/go/go/+/master:LICENSE
package http
// This is a small mod on net/http/httputil/reverseproxy.go
// that boosts performance in some cases
// and compatible to other modules of this project
// Copyright (c) 2024 yusing
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
In *http.Request
// Out is the request which will be sent by the proxy.
// The Rewrite function may modify or replace this request.
// Hop-by-hop headers are removed from this request
// before Rewrite is called.
Out *http.Request
}
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
// X-Forwarded-Proto headers of the outbound request.
//
// - The X-Forwarded-For header is set to the client IP address.
// - The X-Forwarded-Host header is set to the host name requested
// by the client.
// - The X-Forwarded-Proto header is set to "http" or "https", depending
// on whether the inbound request was made on a TLS-enabled connection.
//
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
// [ReverseProxy] when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
// r.SetXForwarded()
// }
func (r *ProxyRequest) SetXForwarded() {
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
if err == nil {
r.Out.Header.Set("X-Forwarded-For", clientIP)
} else {
r.Out.Header.Del("X-Forwarded-For")
}
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
if r.In.TLS == nil {
r.Out.Header.Set("X-Forwarded-Proto", "http")
} else {
r.Out.Header.Set("X-Forwarded-Proto", "https")
}
}
func (r *ProxyRequest) AddXForwarded() {
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
if err == nil {
prior := r.Out.Header["X-Forwarded-For"]
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
r.Out.Header.Set("X-Forwarded-For", clientIP)
} else {
r.Out.Header.Del("X-Forwarded-For")
}
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
if r.In.TLS == nil {
r.Out.Header.Set("X-Forwarded-Proto", "http")
} else {
r.Out.Header.Set("X-Forwarded-Proto", "https")
}
}
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Rewrite must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Rewrite must not access the provided ProxyRequest
// or its contents after returning.
//
// The Forwarded, X-Forwarded, X-Forwarded-Host,
// and X-Forwarded-Proto headers are removed from the
// outbound request before Rewrite is called. See also
// the ProxyRequest.SetXForwarded method.
//
// Unparsable query parameters are removed from the
// outbound request before Rewrite is called.
// The Rewrite function may copy the inbound URL's
// RawQuery to the outbound URL to preserve the original
// parameter string. Note that this can lead to security
// issues if the proxy's interpretation of query parameters
// does not match that of the downstream server.
//
// At most one of Rewrite or Director may be set.
Rewrite func(*ProxyRequest)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called before ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
ServeHTTP http.HandlerFunc
}
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by [io.CopyBuffer].
type BufferPool interface {
Get() []byte
Put([]byte)
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
// NewReverseProxy returns a new [ReverseProxy] that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
//
// NewReverseProxy does not rewrite the Host header.
//
// To customize the ReverseProxy behavior beyond what
// NewReverseProxy provides, use ReverseProxy directly
// with a Rewrite function. The ProxyRequest SetURL method
// may be used to route the outbound request. (Note that SetURL,
// unlike NewReverseProxy, rewrites the Host header
// of the outbound request by default.)
//
// proxy := &ReverseProxy{
// Rewrite: func(r *ProxyRequest) {
// r.SetURL(target)
// r.Out.Host = r.In.Host // if desired
// },
// }
//
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
rp := &ReverseProxy{
Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target)
}, Transport: transport,
}
rp.ServeHTTP = rp.serveHTTP
return rp
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
logger.Errorf("http proxy to %s failed: %s", r.URL.String(), err)
if writeHeader {
rw.WriteHeader(http.StatusBadGateway)
}
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(res); err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
}
return true
}
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
// a Context that carries a cancellation signal, don't
// bother spinning up a goroutine to watch the CloseNotify
// channel (if any).
//
// If the request Context has a nil Done channel (which
// means it is either context.Background, or a custom
// Context implementation with no cancellation signal),
// then consult the CloseNotifier if available.
} else if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return
}
RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
outreq.Header.Del("Forwarded")
// outreq.Header.Del("X-Forwarded-For")
// outreq.Header.Del("X-Forwarded-Host")
// outreq.Header.Del("X-Forwarded-Proto")
pr := &ProxyRequest{
In: req,
Out: outreq,
}
p.Rewrite(pr)
outreq = pr.Out
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
h := rw.Header()
// copyHeader(h, http.Header(header))
for k, vv := range header {
for _, v := range vv {
h.Add(k, v)
}
}
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
if err != nil {
p.errorHandler(rw, outreq, err, false)
errMsg := err.Error()
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: outreq.Proto,
ProtoMajor: outreq.ProtoMajor,
ProtoMinor: outreq.ProtoMinor,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader([]byte(errMsg))),
Request: outreq,
ContentLength: int64(len(errMsg)),
TLS: outreq.TLS,
}
}
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
return
}
if !p.modifyResponse(rw, res, outreq) {
return
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
_, err = io.Copy(rw, res.Body)
if err != nil {
p.errorHandler(rw, req, err, true)
res.Body.Close()
return
}
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
http.NewResponseController(rw).Flush()
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
func UpgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// RemoveHopByHopHeaders removes hop-by-hop headers.
func RemoveHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := UpgradeType(req.Header)
resUpType := UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
}
if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.errorHandler(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"), true)
return
}
rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true)
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
if hijackErr != nil {
p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true)
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
errc := make(chan error, 1)
go func() {
_, err := io.Copy(conn, backConn)
errc <- err
}()
go func() {
_, err := io.Copy(backConn, conn)
errc <- err
}()
<-errc
}
func IsPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
var logger = logrus.WithField("module", "http")

View File

@@ -0,0 +1,7 @@
package http
import "net/http"
func IsSuccess(status int) bool {
return status >= http.StatusOK && status < http.StatusMultipleChoices
}

View File

@@ -0,0 +1,13 @@
package model
type (
AutoCertConfig struct {
Email string `json:"email"`
Domains []string `yaml:",flow" json:"domains"`
CertPath string `yaml:"cert_path" json:"cert_path"`
KeyPath string `yaml:"key_path" json:"key_path"`
Provider string `json:"provider"`
Options AutocertProviderOpt `yaml:",flow" json:"options"`
}
AutocertProviderOpt map[string]any
)

17
internal/models/config.go Normal file
View File

@@ -0,0 +1,17 @@
package model
type Config struct {
Providers ProxyProviders `yaml:",flow" json:"providers"`
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
MatchDomains []string `yaml:"match_domains" json:"match_domains"`
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
}
func DefaultConfig() *Config {
return &Config{
Providers: ProxyProviders{},
TimeoutShutdown: 3,
RedirectToHTTPS: false,
}
}

View File

@@ -0,0 +1,6 @@
package model
type ProxyProviders struct {
Files []string `yaml:"include" json:"include"` // docker, file
Docker map[string]string `yaml:"docker" json:"docker"`
}

View File

@@ -0,0 +1,151 @@
package model
import (
"fmt"
"strconv"
"strings"
. "github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
RawEntry struct {
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
Port string `yaml:"port" json:"port"`
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares"`
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
}
RawEntries = F.Map[string, *RawEntry]
)
var NewProxyEntries = F.NewMapOf[string, *RawEntry]
func (e *RawEntry) FillMissingFields() bool {
isDocker := e.ProxyProperties != nil
if !isDocker {
e.ProxyProperties = &D.ProxyProperties{}
}
lp, pp, extra := e.splitPorts()
if port, ok := ServiceNamePortMapTCP[e.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
if e.Scheme == "" {
e.Scheme = "tcp"
}
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
if e.Scheme == "" {
e.Scheme = "http"
}
} else if pp == "" && e.Scheme == "https" {
pp = "443"
} else if pp == "" {
if p, ok := F.FirstValueOf(e.PrivatePortMapping); ok {
pp = fmt.Sprint(p.PrivatePort)
} else {
pp = "80"
}
}
// replace private port with public port (if any)
if isDocker && e.NetworkMode != "host" {
if p, ok := e.PrivatePortMapping[pp]; ok {
pp = fmt.Sprint(p.PublicPort)
}
if _, ok := e.PublicPortMapping[pp]; !ok { // port is not exposed, but specified
// try to fallback to first public port
if p, ok := F.FirstValueOf(e.PublicPortMapping); ok {
pp = fmt.Sprint(p.PublicPort)
}
// ignore only if it is NOT RUNNING
// because stopped containers
// will have empty port mapping got from docker
if e.Running {
return false
}
}
}
if e.Scheme == "" && isDocker {
if p, ok := e.PublicPortMapping[pp]; ok && p.Type == "udp" {
e.Scheme = "udp"
}
}
if e.Scheme == "" {
if lp != "" {
e.Scheme = "tcp"
} else if strings.HasSuffix(pp, "443") {
e.Scheme = "https"
} else if _, ok := WellKnownHTTPPorts[pp]; ok {
e.Scheme = "http"
} else {
// assume its http
e.Scheme = "http"
}
}
if e.Host == "" {
e.Host = "localhost"
}
if e.IdleTimeout == "" {
e.IdleTimeout = IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = StopMethodDefault
}
e.Port = joinPorts(lp, pp, extra)
return true
}
func (e *RawEntry) splitPorts() (lp string, pp string, extra string) {
portSplit := strings.Split(e.Port, ":")
if len(portSplit) == 1 {
pp = portSplit[0]
} else {
lp = portSplit[0]
pp = portSplit[1]
}
if len(portSplit) > 2 {
extra = strings.Join(portSplit[2:], ":")
}
return
}
func joinPorts(lp string, pp string, extra string) string {
s := make([]string, 0, 3)
if lp != "" {
s = append(s, lp)
}
if pp != "" {
s = append(s, pp)
}
if extra != "" {
s = append(s, extra)
}
return strings.Join(s, ":")
}

142
internal/proxy/entry.go Normal file
View File

@@ -0,0 +1,142 @@
package proxy
import (
"fmt"
"net/url"
"time"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
T "github.com/yusing/go-proxy/internal/proxy/fields"
)
type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
URL *url.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
Middlewares D.NestedLabelMap
/* Docker only */
IdleTimeout time.Duration
WakeTimeout time.Duration
StopMethod T.StopMethod
StopTimeout int
StopSignal T.Signal
DockerHost string
ContainerName string
ContainerRunning bool
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"`
Port T.StreamPort `json:"port"`
}
)
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.DockerHost != ""
}
func ValidateEntry(m *M.RawEntry) (any, E.NestedError) {
if !m.FillMissingFields() {
return nil, E.Missing("fields")
}
scheme, err := T.NewScheme(m.Scheme)
if err.HasError() {
return nil, err
}
var entry any
e := E.NewBuilder("error validating entry")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err.HasError() {
return nil, err
}
return entry, nil
}
func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
return nil
}
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
ContainerRunning: m.Running,
}
}
func validateStreamEntry(m *M.RawEntry, b E.Builder) *StreamEntry {
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := T.ValidateStreamScheme(m.Scheme)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
}
}

View File

@@ -0,0 +1,6 @@
package fields
type (
Alias string
NewAlias = Alias
)

View File

@@ -0,0 +1,19 @@
package fields
import (
"net/http"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h := make(http.Header)
for k, v := range headers {
vSplit := strings.Split(v, ",")
for _, header := range vSplit {
h.Add(k, strings.TrimSpace(header))
}
}
return h, nil
}

View File

@@ -0,0 +1,12 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Host string
type Subdomain = Alias
func ValidateHost[String ~string](s String) (Host, E.NestedError) {
return Host(s), nil
}

View File

@@ -0,0 +1,24 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm {
case "", "forward":
return PathMode(pm), nil
default:
return "", E.Invalid("path mode", pm)
}
}
func (p PathMode) IsRemove() bool {
return p == ""
}
func (p PathMode) IsForward() bool {
return p == "forward"
}

View File

@@ -0,0 +1,37 @@
package fields
import (
"regexp"
E "github.com/yusing/go-proxy/internal/error"
)
type PathPattern string
type PathPatterns = []PathPattern
func NewPathPattern(s string) (PathPattern, E.NestedError) {
if len(s) == 0 {
return "", E.Invalid("path", "must not be empty")
}
if !pathPattern.MatchString(s) {
return "", E.Invalid("path pattern", s)
}
return PathPattern(s), nil
}
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
if len(s) == 0 {
return []PathPattern{"/"}, nil
}
pp := make(PathPatterns, len(s))
for i, v := range s {
if pattern, err := NewPathPattern(v); err.HasError() {
return nil, err
} else {
pp[i] = pattern
}
}
return pp, nil
}
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)

View File

@@ -0,0 +1,47 @@
package fields
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils/testing"
)
var validPatterns = []string{
"/",
"/index.html",
"/somepage/",
"/drive/abc.mp4",
"/{$}",
"/some-page/{$}",
"GET /",
"GET /static/{$}",
"GET /drive/abc.mp4",
"GET /drive/abc.mp4/{$}",
"POST /auth",
"DELETE /user/",
"PUT /storage/id/",
}
var invalidPatterns = []string{
"/$",
"/{$}{$}",
"/{$}/{$}",
"/index.html$",
"get /",
"GET/",
"GET /$",
"GET /drive/{$}/abc.mp4/",
"OPTION /config/{$}/abc.conf/{$}",
}
func TestPathPatternRegex(t *testing.T) {
for _, pattern := range validPatterns {
_, err := NewPathPattern(pattern)
U.ExpectNoError(t, err.Error())
}
for _, pattern := range invalidPatterns {
_, err := NewPathPattern(pattern)
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error())
}
}

View File

@@ -0,0 +1,41 @@
package fields
import (
"strconv"
E "github.com/yusing/go-proxy/internal/error"
)
type Port int
func ValidatePort[String ~string](v String) (Port, E.NestedError) {
p, err := strconv.Atoi(string(v))
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
}
return ValidatePortInt(p)
}
func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) {
p := Port(v)
if !p.inBound() {
return ErrPort, E.OutOfRange("port", p)
}
return p, nil
}
func (p Port) inBound() bool {
return p >= MinPort && p <= MaxPort
}
func (p Port) String() string {
return strconv.Itoa(int(p))
}
const (
MinPort = 0
MaxPort = 65535
ErrPort = Port(-1)
NoPort = Port(-1)
ZeroPort = Port(0)
)

View File

@@ -0,0 +1,21 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.NestedError) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), nil
}
return "", E.Invalid("scheme", s)
}
func (s Scheme) IsHTTP() bool { return s == "http" }
func (s Scheme) IsHTTPS() bool { return s == "https" }
func (s Scheme) IsTCP() bool { return s == "tcp" }
func (s Scheme) IsUDP() bool { return s == "udp" }
func (s Scheme) IsStream() bool { return s.IsTCP() || s.IsUDP() }

View File

@@ -0,0 +1,17 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View File

@@ -0,0 +1,23 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View File

@@ -0,0 +1,56 @@
package fields
import (
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type StreamPort struct {
ListeningPort Port `json:"listening"`
ProxyPort Port `json:"proxy"`
}
func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
split := strings.Split(p, ":")
switch len(split) {
case 1:
split = []string{"0", split[0]}
case 2:
break
default:
return ErrStreamPort, E.Invalid("stream port", p).With("too many colons")
}
listeningPort, err := ValidatePort(split[0])
if err != nil {
return ErrStreamPort, err.Subject("listening port")
}
proxyPort, err := ValidatePort(split[1])
if err.Is(E.ErrOutOfRange) {
return ErrStreamPort, err.Subject("proxy port")
} else if proxyPort == 0 {
return ErrStreamPort, E.Invalid("proxy port", p)
} else if err != nil {
proxyPort, err = parseNameToPort(split[1])
if err != nil {
return ErrStreamPort, E.Invalid("proxy port", proxyPort)
}
}
return StreamPort{listeningPort, proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.NestedError) {
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return ErrPort, E.Invalid("service", name)
}
return Port(port), nil
}
var ErrStreamPort = StreamPort{ErrPort, ErrPort}

View File

@@ -0,0 +1,48 @@
package fields
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var validPorts = []string{
"1234:5678",
"0:2345",
"2345",
"1234:postgres",
}
var invalidPorts = []string{
"",
"123:",
"0:",
":1234",
"1234:1234:1234",
"qwerty",
"asdfgh:asdfgh",
"1234:asdfgh",
}
var outOfRangePorts = []string{
"-1:1234",
"1234:-1",
"65536",
"0:65536",
}
func TestStreamPort(t *testing.T) {
for _, port := range validPorts {
_, err := ValidateStreamPort(port)
ExpectNoError(t, err.Error())
}
for _, port := range invalidPorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrInvalid, err.Error())
}
for _, port := range outOfRangePorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrOutOfRange, err.Error())
}
}

View File

@@ -0,0 +1,43 @@
package fields
import (
"fmt"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type StreamScheme struct {
ListeningScheme Scheme `json:"listening"`
ProxyScheme Scheme `json:"proxy"`
}
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
ss = &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {
parts = []string{s, s}
} else if len(parts) != 2 {
return nil, E.Invalid("stream scheme", s)
}
ss.ListeningScheme, err = NewScheme(parts[0])
if err.HasError() {
return nil, err
}
ss.ProxyScheme, err = NewScheme(parts[1])
if err.HasError() {
return nil, err
}
return ss, nil
}
func (s StreamScheme) String() string {
return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme)
}
// IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal.
//
// It returns a boolean value indicating whether the ListeningScheme and ProxyScheme are equal.
func (s StreamScheme) IsCoherent() bool {
return s.ListeningScheme == s.ProxyScheme
}

View File

@@ -0,0 +1,18 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

206
internal/proxy/provider/docker.go Executable file
View File

@@ -0,0 +1,206 @@
package provider
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher"
)
type DockerProvider struct {
dockerHost, hostname string
}
var AliasRefRegex = regexp.MustCompile(`#\d+`)
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
hostname, err := D.ParseDockerHostname(dockerHost)
if err.HasError() {
return nil, err
}
return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil
}
func (p *DockerProvider) String() string {
return fmt.Sprintf("docker:%s", p.dockerHost)
}
func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
routes = R.NewRoutes()
entries := M.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() {
return routes, E.FailWith("connect to docker", err)
}
errors := E.NewBuilder("errors when parse docker labels")
for _, c := range info.Containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() {
errors.Add(err)
}
// although err is not nil
// there may be some valid entries in `en`
dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error
dups.RangeAll(func(k string, v *M.RawEntry) {
errors.Addf("duplicate alias %s", k)
})
}
entries.RangeAll(func(_ string, e *M.RawEntry) {
e.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
return routes, errors.Build()
}
func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
}
})
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
b.Add(E.FailWith("inspect container", err))
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *M.RawEntry) {
if routes.Has(alias) {
b.Add(E.Duplicated("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
}
})
return
}
// Returns a list of proxy entries for a container.
// Always non-nil
func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.RawEntries, E.NestedError) {
entries := M.NewProxyEntries()
if container.IsExcluded {
return entries, nil
}
// init entries map for all aliases
for _, a := range container.Aliases {
entries.Store(a, &M.RawEntry{
Alias: a,
Host: p.hostname,
ProxyProperties: container.ProxyProperties,
})
}
errors := E.NewBuilder("failed to apply label")
for key, val := range container.Labels {
errors.Add(p.applyLabel(container, entries, key, val))
}
// remove all entries that failed to fill in missing fields
entries.RemoveAll(func(re *M.RawEntry) bool {
return !re.FillMissingFields()
})
return entries, errors.Build().Subject(container.ContainerName)
}
func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
refErr := E.NewBuilder("errors parsing alias references")
replaceIndexRef := func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
refErr.Add(E.Invalid("integer", ref))
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.OutOfRange("index", ref))
return ref
}
return container.Aliases[index-1]
}
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
}
if lbl.Namespace != D.NSProxy {
return
}
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *M.RawEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
}
})
} else {
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef)
lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string {
logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#"))
return replaceIndexRef(s)
})
if refErr.HasError() {
b.Add(refErr.Build())
return
}
config, ok := entries.Load(lbl.Target)
if !ok {
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
}
}
return
}

View File

@@ -0,0 +1,307 @@
package provider
import (
"strings"
"testing"
"github.com/docker/docker/api/types"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
T "github.com/yusing/go-proxy/internal/proxy/fields"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var dummyNames = []string{"/a"}
func TestApplyLabelFieldValidity(t *testing.T) {
pathPatterns := `
- /
- POST /upload/{$}
- GET /static
`[1:]
pathPatternsExpect := []string{
"/",
"POST /upload/{$}",
"GET /static",
}
middlewaresExpect := D.NestedLabelMap{
"middleware1": {
"prop1": "value1",
"prop2": "value2",
},
"middleware2": {
"prop3": "value3",
"prop4": "value4",
},
}
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault,
D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM",
D.LabelStopTimeout: common.StopTimeoutDefault,
D.LabelWakeTimeout: common.WakeTimeoutDefault,
"proxy.*.no_tls_verify": "true",
"proxy.*.scheme": "https",
"proxy.*.host": "app",
"proxy.*.port": "4567",
"proxy.a.no_tls_verify": "true",
"proxy.a.path_patterns": pathPatterns,
"proxy.a.middlewares.middleware1.prop1": "value1",
"proxy.a.middlewares.middleware1.prop2": "value2",
"proxy.a.middlewares.middleware2.prop3": "value3",
"proxy.a.middlewares.middleware2.prop4": "value4",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 4567, PublicPort: 8888},
}}, ""))
ExpectNoError(t, err.Error())
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
ExpectEqual(t, a.Scheme, "https")
ExpectEqual(t, b.Scheme, "https")
ExpectEqual(t, a.Host, "app")
ExpectEqual(t, b.Host, "app")
ExpectEqual(t, a.Port, "8888")
ExpectEqual(t, b.Port, "8888")
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
ExpectDeepEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, len(b.PathPatterns), 0)
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.StopSignal, "SIGTERM")
ExpectEqual(t, b.StopSignal, "SIGTERM")
}
func TestApplyLabel(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.a.no_tls_verify": "true",
"proxy.a.port": "3333",
"proxy.b.port": "1234",
"proxy.c.scheme": "https",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 1111},
{Type: "tcp", PrivatePort: 4444, PublicPort: 1234},
}}, "",
))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Port, "1111")
ExpectEqual(t, a.NoTLSVerify, true)
ExpectEqual(t, b.Scheme, "http")
ExpectEqual(t, b.Port, "1234")
ExpectEqual(t, c.Scheme, "https")
// map does not necessary follow the order above
ExpectEqualAny(t, c.Port, []string{"1111", "1234"})
}
func TestApplyLabelWithRef(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.#1.host": "localhost",
"proxy.#1.port": "4444",
"proxy.#2.port": "9999",
"proxy.#3.port": "1111",
"proxy.#3.scheme": "https",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 9999},
{Type: "tcp", PrivatePort: 4444, PublicPort: 5555},
{Type: "tcp", PrivatePort: 1111, PublicPort: 2222},
}}, ""))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Host, "localhost")
ExpectEqual(t, a.Port, "5555")
ExpectEqual(t, b.Port, "9999")
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port, "2222")
}
func TestApplyLabelWithRefIndexError(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#1.host": "localhost",
"proxy.#4.scheme": "https",
}}, "")
_, err := p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#0.host": "localhost",
}}, ""))
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
}
func TestStreamDefaultValues(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.*.no_tls_verify": "true",
},
Ports: []types.Port{
{Type: "udp", PrivatePort: 1234, PublicPort: 5678},
}}, "",
)
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
raw, ok := entries.Load("a")
ExpectTrue(t, ok)
entry, err := P.ValidateEntry(raw)
ExpectNoError(t, err.Error())
a := ExpectType[*P.StreamEntry](t, entry)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Port.ListeningPort, 0)
ExpectEqual(t, a.Port.ProxyPort, 5678)
}
func TestExplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
D.LabelExclude: "true",
"proxy.a.no_tls_verify": "true",
}}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.a.no_tls_verify": "true",
},
State: "running",
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExcludeNoExposedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
State: "running",
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectFalse(t, ok)
}
func TestNotExcludeSpecifiedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
Labels: map[string]string{
"proxy.redis.port": "6379:6379", // but specified in label
},
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectTrue(t, ok)
}
func TestNotExcludeNonExposedPortHostNetwork(t *testing.T) {
var p DockerProvider
cont := &types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
Labels: map[string]string{
"proxy.redis.port": "6379:6379",
},
}
cont.HostConfig.NetworkMode = "host"
entries, err := p.entriesFromContainerLabels(D.FromDocker(cont, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectTrue(t, ok)
}

View File

@@ -0,0 +1,98 @@
package provider
import (
"errors"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
R "github.com/yusing/go-proxy/internal/route"
U "github.com/yusing/go-proxy/internal/utils"
W "github.com/yusing/go-proxy/internal/watcher"
)
type FileProvider struct {
fileName string
path string
}
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
}
_, err := os.Stat(impl.path)
switch {
case err == nil:
return impl, nil
case errors.Is(err, os.ErrNotExist):
return nil, E.NotExist("file", impl.path)
default:
return nil, E.UnexpectedError(err)
}
}
func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
}
func (p FileProvider) String() string {
return p.fileName
}
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
b.Add(err)
return
}
routes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Start())
})
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
routes = R.NewRoutes()
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := M.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err.HasError() {
b.Add(E.FailWith("read file", err))
return
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
b.Add(err)
return
}
}
if err = entries.UnmarshalFromYAML(data); err.HasError() {
b.Add(err)
return
}
return R.FromEntries(entries)
}
func (p *FileProvider) NewWatcher() W.Watcher {
return W.NewConfigFileWatcher(p.fileName)
}

View File

@@ -0,0 +1,178 @@
package provider
import (
"context"
"path"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher"
)
type (
Provider struct {
ProviderImpl `json:"-"`
name string
t ProviderType
routes R.Routes
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
l *logrus.Entry
}
ProviderImpl interface {
NewWatcher() W.Watcher
// even returns error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult
String() string
}
ProviderType string
EventResult struct {
nRemoved int
nAdded int
err E.NestedError
}
)
const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
)
func newProvider(name string, t ProviderType) *Provider {
p := &Provider{
name: name,
t: t,
routes: R.NewRoutes(),
}
p.l = logrus.WithField("provider", p)
return p
}
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
name := path.Base(filename)
p = newProvider(name, ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return
}
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
p = newProvider(name, ProviderTypeDocker)
p.ProviderImpl, err = DockerProviderImpl(dockerHost)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return
}
func (p *Provider) GetName() string {
return p.name
}
func (p *Provider) GetType() ProviderType {
return p.t
}
// to work with json marshaller
func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) StartAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors in routes")
defer errors.To(&res)
// start watcher no matter load success or not
go p.watchEvents()
nStarted := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStarted++
}
})
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return
}
func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil {
p.watcherCancel()
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
defer errors.To(&res)
nStopped := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStopped++
}
})
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return
}
func (p *Provider) RangeRoutes(do func(string, R.Route)) {
p.routes.RangeAll(do)
}
func (p *Provider) GetRoute(alias string) (R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError
p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 {
p.l.Infof("loaded %d routes", p.routes.Size())
return err
}
return E.FailWith("loading routes", err)
}
func (p *Provider) watchEvents() {
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
events, errs := p.watcher.Events(p.watcherCtx)
l := p.l.WithField("module", "watcher")
for {
select {
case <-p.watcherCtx.Done():
return
case event := <-events:
res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() {
l.Error(res.err)
}
case err := <-errs:
if err == nil || err.Is(context.Canceled) {
continue
}
l.Errorf("watcher error: %s", err)
}
}
}

View File

@@ -0,0 +1,8 @@
package route
import (
"time"
)
const udpBufferSize = 8192
const streamStopListenTimeout = 1 * time.Second

229
internal/route/http.go Executable file
View File

@@ -0,0 +1,229 @@
package route
import (
"encoding/json"
"fmt"
"slices"
"sync"
"net/http"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/http"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/route/middleware"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
HTTPRoute struct {
Alias PT.Alias `json:"alias"`
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
entry *P.ReverseProxyEntry
mux *http.ServeMux
handler *ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
}
URL url.URL
SubdomainKey = PT.Alias
)
var (
findMuxFunc = findMuxAnyDomain
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
)
func SetFindMuxDomains(domains []string) {
if len(domains) == 0 {
findMuxFunc = findMuxAnyDomain
} else {
findMuxFunc = findMuxByDomains(domains)
}
}
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans *http.Transport
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
if entry.NoTLSVerify {
trans = common.DefaultTransportNoTLS.Clone()
} else {
trans = common.DefaultTransport.Clone()
}
rp := NewReverseProxy(entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
if err != nil {
return nil, err
}
}
if entry.UseIdleWatcher() {
// allow time for response header up to `WakeTimeout`
if entry.WakeTimeout > trans.ResponseHeaderTimeout {
trans.ResponseHeaderTimeout = entry.WakeTimeout
}
regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry)
if err.HasError() {
return err
}
// patch round-tripper
rp.Transport = watcher.PatchRoundTripper(trans)
return nil
}
unregIdleWatcher = func() {
idlewatcher.Unregister(entry.ContainerName)
rp.Transport = trans
}
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.Duplicated("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
handler: rp,
regIdleWatcher: regIdleWatcher,
unregIdleWatcher: unregIdleWatcher,
}
return r, nil
}
func (r *HTTPRoute) String() string {
return string(r.Alias)
}
func (r *HTTPRoute) Start() E.NestedError {
if r.mux != nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
r.unregIdleWatcher = nil
return err
}
}
r.mux = http.NewServeMux()
for _, p := range r.PathPatterns {
r.mux.HandleFunc(string(p), r.handler.ServeHTTP)
}
httpRoutes.Store(r.Alias, r)
return nil
}
func (r *HTTPRoute) Stop() E.NestedError {
if r.mux == nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.unregIdleWatcher != nil {
r.unregIdleWatcher()
r.unregIdleWatcher = nil
}
r.mux = nil
httpRoutes.Delete(r.Alias)
return nil
}
func (u *URL) String() string {
return (*url.URL)(u).String()
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMuxFunc(r.Host)
if err != nil {
logrus.Error(E.Failure("request").
Subjectf("%s %s", r.Method, r.URL.String()).
With(err))
acceptTypes := r.Header.Values("Accept")
switch {
case slices.Contains(acceptTypes, "text/html"):
case slices.Contains(acceptTypes, "application/json"):
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
default:
http.Error(w, err.Error(), http.StatusNotFound)
}
return
}
mux.ServeHTTP(w, r)
}
func findMuxAnyDomain(host string) (*http.ServeMux, error) {
hostSplit := strings.Split(host, ".")
n := len(hostSplit)
if n <= 2 {
return nil, fmt.Errorf("missing subdomain in url")
}
sd := strings.Join(hostSplit[:n-2], ".")
if r, ok := httpRoutes.Load(PT.Alias(sd)); ok {
return r.mux, nil
}
return nil, fmt.Errorf("no such route: %s", sd)
}
func findMuxByDomains(domains []string) func(host string) (*http.ServeMux, error) {
return func(host string) (*http.ServeMux, error) {
var subdomain string
for _, domain := range domains {
if !strings.HasPrefix(domain, ".") {
domain = "." + domain
}
subdomain = strings.TrimSuffix(host, domain)
if len(subdomain) < len(host) {
break
}
}
if len(subdomain) == len(host) { // not matched
return nil, fmt.Errorf("%s does not match any base domain", host)
}
if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok {
return r.mux, nil
}
return nil, fmt.Errorf("no such route: %s", subdomain)
}
}

View File

@@ -0,0 +1,70 @@
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
const staticFilePathPrefix = "/$gperrorpage/"
var CustomErrorPage = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, staticFilePathPrefix) {
filename := path[len(staticFilePathPrefix):]
file, ok := error_page.GetStaticFile(filename)
if !ok {
http.NotFound(w, r)
errPageLogger.Errorf("unable to load resource %s", filename)
} else {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
w.Write(file)
}
return
}
next.ServeHTTP(w, r)
},
modifyResponse: func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gpHTTP.GetContentType(resp.Header)
if !gpHTTP.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View File

@@ -0,0 +1,249 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
forwardAuth struct {
*forwardAuthOpts
m *Middleware
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
const (
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
)
var ForwardAuth = newForwardAuth()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func newForwardAuth() (fa *forwardAuth) {
fa = new(forwardAuth)
fa.m = new(Middleware)
fa.m.labelParserMap = D.ValueParserMap{
"trust_forward_header": D.BoolParser,
"auth_response_headers": D.YamlStringListParser,
"add_auth_cookies_to_response": D.YamlStringListParser,
}
fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
tr, ok := rp.Transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = common.DefaultTransport.Clone()
}
faWithOpts := new(forwardAuth)
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
faWithOpts.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
faWithOpts.m = &Middleware{
impl: faWithOpts,
before: faWithOpts.forward,
}
err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
_, err = E.Check(url.Parse(faWithOpts.Address))
if err != nil {
return nil, E.Invalid("address", faWithOpts.Address)
}
return faWithOpts.m, nil
}
return
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
gpHTTP.RemoveHop(req.Header)
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
gpHTTP.CopyHeader(faReq.Header, req.Header)
gpHTTP.RemoveHop(faReq.Header)
gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
}
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
}
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
}
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
type (
Error = E.NestedError
ReverseProxy = gpHTTP.ReverseProxy
ProxyRequest = gpHTTP.ProxyRequest
Request = http.Request
Response = http.Response
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
RewriteFunc func(req *ProxyRequest)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
OptionsRaw = map[string]any
Options any
Middleware struct {
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
}
)
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
for name, opts := range middlewares {
m, ok := Get(name)
if !ok {
invalidM.Addf("%s", name)
continue
}
m, err := m.WithOptionsClone(opts, rp)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
}
}
if invalidM.HasError() {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
}
}
if len(rewrites) > 0 {
if rp.Rewrite != nil {
rewrites = append([]RewriteFunc{rp.Rewrite}, rewrites...)
}
rp.Rewrite = func(req *ProxyRequest) {
for _, rewrite := range rewrites {
rewrite(req)
}
}
}
if len(modResps) > 0 {
if rp.ModifyResponse != nil {
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
}
rp.ModifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware ModifyResponse")
for _, mr := range modResps {
b.AddE(mr(res))
}
return b.Build().Error()
}
}
return
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"fmt"
"strings"
D "github.com/yusing/go-proxy/internal/docker"
)
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
return
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"add_x_forwarded": AddXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
D.RegisterNamespace(name, m.labelParserMap)
}
}
for m, names := range names {
if len(names) > 1 {
m.name = fmt.Sprintf("%s (a.k.a. %s)", names[0], strings.Join(names[1:], ", "))
} else {
m.name = names[0]
}
}
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
modifyRequest struct {
*modifyRequestOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyRequest = newModifyRequest()
func newModifyRequest() (mr *modifyRequest) {
mr = new(modifyRequest)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyRequest)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
rewrite: mrWithOpts.modifyRequest,
}
mrWithOpts.modifyRequestOpts = new(modifyRequestOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyRequest) modifyRequest(req *ProxyRequest) {
for k, v := range mr.SetHeaders {
req.Out.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Out.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Out.Header.Del(k)
}
}

View File

@@ -0,0 +1,34 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,61 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
modifyResponse struct {
*modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyResponse = newModifyResponse()
func newModifyResponse() (mr *modifyResponse) {
mr = new(modifyResponse)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyResponse)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
modifyResponse: mrWithOpts.modifyResponse,
}
mrWithOpts.modifyResponseOpts = new(modifyResponseOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}

View File

@@ -0,0 +1,35 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
})
}

View File

@@ -0,0 +1,19 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
)
var RedirectHTTP = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
},
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/common"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/http"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
const testHost = "example.com"
func init() {
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestHeaderRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
}, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
proxyURL string
body []byte
scheme string
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
if args == nil {
args = new(testArgs)
}
if args.body != nil {
body = bytes.NewReader(args.body)
}
if args.scheme == "" || args.scheme == "http" {
requestTarget = "http://" + testHost
} else if args.scheme == "https" {
requestTarget = "https://" + testHost
} else {
panic("typo?")
}
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil {
panic("bug occurred")
}
if args.proxyURL != "" {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
if setOptErr != nil {
return nil, setOptErr
}
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
Data: data,
}, nil
}

View File

@@ -0,0 +1,9 @@
package middleware
var AddXForwarded = &Middleware{
rewrite: (*ProxyRequest).AddXForwarded,
}
var SetXForwarded = &Middleware{
rewrite: (*ProxyRequest).SetXForwarded,
}

95
internal/route/route.go Executable file
View File

@@ -0,0 +1,95 @@
package route
import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
P "github.com/yusing/go-proxy/internal/proxy"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
Route interface {
RouteImpl
Entry() *M.RawEntry
Type() RouteType
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteImpl interface {
Start() E.NestedError
Stop() E.NestedError
String() string
}
RouteType string
route struct {
RouteImpl
type_ RouteType
entry *M.RawEntry
}
)
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias
var NewRoutes = F.NewMapOf[string, Route]
func NewRoute(en *M.RawEntry) (Route, E.NestedError) {
rt, err := P.ValidateEntry(en)
if err != nil {
return nil, err
}
var t RouteType
switch e := rt.(type) {
case *P.StreamEntry:
rt, err = NewStreamRoute(e)
t = RouteTypeStream
case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
default:
panic("bug: should not reach here")
}
if err != nil {
return nil, err
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, nil
}
func (rt *route) Entry() *M.RawEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host))
return url
}
func FromEntries(entries M.RawEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *M.RawEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)
}
})
return routes, b.Build()
}

141
internal/route/stream.go Executable file
View File

@@ -0,0 +1,141 @@
package route
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
)
type StreamRoute struct {
*P.StreamEntry
StreamImpl `json:"-"`
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
connCh chan any
started atomic.Bool
l logrus.FieldLogger
}
type StreamImpl interface {
Setup() error
Accept() (any, error)
Handle(any) error
CloseListeners()
String() string
}
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
// TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
base := &StreamRoute{
StreamEntry: entry,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base)
} else {
base.StreamImpl = NewUDPRoute(base)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, nil
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
}
func (r *StreamRoute) Start() E.NestedError {
if r.started.Load() {
return nil
}
r.ctx, r.cancel = context.WithCancel(context.Background())
r.wg.Wait()
if err := r.Setup(); err != nil {
return E.FailWith("setup", err)
}
r.l.Infof("listening on port %d", r.Port.ListeningPort)
r.started.Store(true)
r.wg.Add(2)
go r.grAcceptConnections()
go r.grHandleConnections()
return nil
}
func (r *StreamRoute) Stop() E.NestedError {
if !r.started.Load() {
return nil
}
l := r.l
r.cancel()
r.CloseListeners()
done := make(chan struct{}, 1)
go func() {
r.wg.Wait()
close(done)
}()
timeout := time.After(streamStopListenTimeout)
for {
select {
case <-done:
l.Debug("stopped listening")
return nil
case <-timeout:
return E.FailedWhy("stop", "timed out")
}
}
}
func (r *StreamRoute) grAcceptConnections() {
defer r.wg.Done()
for {
select {
case <-r.ctx.Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-r.ctx.Done():
return
default:
r.l.Error(err)
continue
}
}
r.connCh <- conn
}
}
}
func (r *StreamRoute) grHandleConnections() {
defer r.wg.Done()
for {
select {
case <-r.ctx.Done():
return
case conn := <-r.connCh:
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
}
}()
}
}
}

80
internal/route/tcp.go Executable file
View File

@@ -0,0 +1,80 @@
package route
import (
"context"
"fmt"
"net"
"sync"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
)
const tcpDialTimeout = 5 * time.Second
type (
Pipes []U.BidirectionalPipe
TCPRoute struct {
*StreamRoute
listener net.Listener
pipe Pipes
mu sync.Mutex
}
)
func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{
StreamRoute: base,
pipe: make(Pipes, 0),
}
}
func (route *TCPRoute) Setup() error {
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
//! this read the allocated port from orginal ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in
return nil
}
func (route *TCPRoute) Accept() (any, error) {
return route.listener.Accept()
}
func (route *TCPRoute) Handle(c any) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
ctx, cancel := context.WithTimeout(route.ctx, tcpDialTimeout)
defer cancel()
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
if err != nil {
return err
}
route.mu.Lock()
pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
func (route *TCPRoute) CloseListeners() {
if route.listener == nil {
return
}
route.listener.Close()
route.listener = nil
}

129
internal/route/udp.go Executable file
View File

@@ -0,0 +1,129 @@
package route
import (
"fmt"
"io"
"net"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPRoute struct {
*StreamRoute
connMap UDPConnMap
listeningConn *net.UDPConn
targetAddr *net.UDPAddr
}
UDPConn struct {
src *net.UDPConn
dst *net.UDPConn
U.BidirectionalPipe
}
UDPConnMap = F.Map[string, *UDPConn]
)
var NewUDPConnMap = F.NewMapOf[string, *UDPConn]
func NewUDPRoute(base *StreamRoute) StreamImpl {
return &UDPRoute{
StreamRoute: base,
connMap: NewUDPConnMap(),
}
}
func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
if err != nil {
return err
}
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
source.Close()
return err
}
//! this read the allocated listeningPort from orginal ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
route.targetAddr = raddr
return nil
}
func (route *UDPRoute) Accept() (any, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
key := srcAddr.String()
conn, ok := route.connMap.Load(key)
if !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
if err != nil {
srcConn.Close()
return nil, err
}
conn = &UDPConn{
srcConn,
dstConn,
U.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap.Store(key, conn)
}
_, err = conn.dst.Write(buffer[:nRead])
return conn, err
}
func (route *UDPRoute) Handle(c any) error {
return c.(*UDPConn).Start()
}
func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
route.listeningConn = nil
}
route.connMap.RangeAll(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err)
}
if err := conn.dst.Close(); err != nil {
route.l.Error("error closing dst conn: %s", err)
}
})
route.connMap.Clear()
}
type sourceRWCloser struct {
server *net.UDPConn
*net.UDPConn
}
func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}

View File

@@ -0,0 +1,25 @@
package server
var proxyServer, apiServer *server
func InitProxyServer(opt Options) *server {
if proxyServer == nil {
proxyServer = NewServer(opt)
}
return proxyServer
}
func InitAPIServer(opt Options) *server {
if apiServer == nil {
apiServer = NewServer(opt)
}
return apiServer
}
func GetProxyServer() *server {
return proxyServer
}
func GetAPIServer() *server {
return apiServer
}

164
internal/server/server.go Normal file
View File

@@ -0,0 +1,164 @@
package server
import (
"crypto/tls"
"log"
"net/http"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"golang.org/x/net/context"
)
type server struct {
Name string
CertProvider *autocert.Provider
http *http.Server
https *http.Server
httpStarted bool
httpsStarted bool
startTime time.Time
}
type Options struct {
Name string
HTTPAddr string
HTTPSAddr string
CertProvider *autocert.Provider
RedirectToHTTPS bool
Handler http.Handler
}
type LogrusWrapper struct {
*logrus.Entry
}
func (l LogrusWrapper) Write(b []byte) (int, error) {
return l.Logger.WriterLevel(logrus.ErrorLevel).Write(b)
}
func NewServer(opt Options) (s *server) {
var httpSer, httpsSer *http.Server
var httpHandler http.Handler
logger := log.Default()
logger.SetOutput(LogrusWrapper{
logrus.WithFields(logrus.Fields{"?": "server", "name": opt.Name}),
})
certAvailable := false
if opt.CertProvider != nil {
_, err := opt.CertProvider.GetCert(nil)
certAvailable = err == nil
}
if certAvailable && opt.RedirectToHTTPS && opt.HTTPSAddr != "" {
httpHandler = redirectToTLSHandler(opt.HTTPSAddr)
} else {
httpHandler = opt.Handler
}
if opt.HTTPAddr != "" {
httpSer = &http.Server{
Addr: opt.HTTPAddr,
Handler: httpHandler,
ErrorLog: logger,
}
}
if certAvailable && opt.HTTPSAddr != "" {
httpsSer = &http.Server{
Addr: opt.HTTPSAddr,
Handler: opt.Handler,
ErrorLog: logger,
TLSConfig: &tls.Config{
GetCertificate: opt.CertProvider.GetCert,
},
}
}
return &server{
Name: opt.Name,
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
}
}
// Start will start the http and https servers.
//
// If both are not set, this does nothing.
//
// Start() is non-blocking
func (s *server) Start() {
if s.http == nil && s.https == nil {
return
}
s.startTime = time.Now()
if s.http != nil {
s.httpStarted = true
logrus.Printf("starting http %s server on %s", s.Name, s.http.Addr)
go func() {
s.handleErr("http", s.http.ListenAndServe())
}()
}
if s.https != nil {
s.httpsStarted = true
logrus.Printf("starting https %s server on %s", s.Name, s.https.Addr)
go func() {
s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath()))
}()
}
}
func (s *server) Stop() {
if s.http == nil && s.https == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(ctx))
s.httpStarted = false
logger.Debugf("HTTP server %q stopped", s.Name)
}
if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(ctx))
s.httpsStarted = false
logger.Debugf("HTTPS server %q stopped", s.Name)
}
}
func (s *server) Uptime() time.Duration {
return time.Since(s.startTime)
}
func (s *server) handleErr(scheme string, err error) {
switch err {
case nil, http.ErrServerClosed:
return
default:
logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err)
}
}
func redirectToTLSHandler(port string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + port
var redirectCode int
if r.Method == http.MethodGet {
redirectCode = http.StatusMovedPermanently
} else {
redirectCode = http.StatusPermanentRedirect
}
http.Redirect(w, r, r.URL.String(), redirectCode)
}
}
var logger = logrus.WithField("module", "server")

115
internal/setup.go Normal file
View File

@@ -0,0 +1,115 @@
package internal
import (
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path"
. "github.com/yusing/go-proxy/internal/common"
)
var branch = GetEnv("GOPROXY_BRANCH", "v0.5")
var baseUrl = fmt.Sprintf("https://github.com/yusing/go-proxy/raw/%s", branch)
var requiredConfigs = []Config{
{ConfigBasePath, true, false, ""},
{ComposeFileName, false, true, ComposeExampleFileName},
{path.Join(ConfigBasePath, ConfigFileName), false, true, ConfigExampleFileName},
}
type Config struct {
Pathname string
IsDir bool
NeedDownload bool
DownloadFileName string
}
func Setup() {
log.Println("setting up go-proxy")
log.Println("branch:", branch)
os.Chdir("/setup")
for _, config := range requiredConfigs {
config.setup()
}
log.Println("done")
}
func (c *Config) setup() {
if c.IsDir {
mkdir(c.Pathname)
return
}
if !c.NeedDownload {
touch(c.Pathname)
return
}
fetch(c.DownloadFileName, c.Pathname)
}
func hasFileOrDir(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func mkdir(pathname string) {
_, err := os.Stat(pathname)
if err != nil && os.IsNotExist(err) {
log.Printf("creating directory %q\n", pathname)
err := os.MkdirAll(pathname, 0o755)
if err != nil {
log.Fatalf("failed: %s\n", err)
}
return
}
if err != nil {
log.Fatalf("failed: %s\n", err)
}
}
func touch(pathname string) {
if hasFileOrDir(pathname) {
return
}
log.Printf("creating file %q\n", pathname)
_, err := os.Create(pathname)
if err != nil {
log.Fatalf("failed: %s\n", err)
}
}
func fetch(remoteFilename string, outFileName string) {
if hasFileOrDir(outFileName) {
return
}
log.Printf("downloading %q\n", remoteFilename)
url, err := url.JoinPath(baseUrl, remoteFilename)
if err != nil {
log.Fatalf("unexpected error: %s\n", err)
}
resp, err := http.Get(url)
if err != nil {
log.Fatalf("http request failed: %s\n", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("error reading response body: %s\n", err)
}
err = os.WriteFile(outFileName, body, 0o644)
if err != nil {
log.Fatalf("failed to write to file: %s\n", err)
}
log.Printf("downloaded %q\n", outFileName)
}

59
internal/utils/format.go Normal file
View File

@@ -0,0 +1,59 @@
package utils
import (
"fmt"
"strings"
"time"
)
func FormatDuration(d time.Duration) string {
// Get total seconds from duration
totalSeconds := int64(d.Seconds())
// Calculate days, hours, minutes, and seconds
days := totalSeconds / (24 * 3600)
hours := (totalSeconds % (24 * 3600)) / 3600
minutes := (totalSeconds % 3600) / 60
seconds := totalSeconds % 60
// Create a slice to hold parts of the duration
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%d day%s", days, pluralize(days)))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%d hour%s", hours, pluralize(hours)))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%d minute%s", minutes, pluralize(minutes)))
}
if seconds > 0 {
parts = append(parts, fmt.Sprintf("%d second%s", seconds, pluralize(seconds)))
}
// Join the parts with appropriate connectors
if len(parts) == 0 {
return "0 Seconds"
}
if len(parts) == 1 {
return parts[0]
}
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
}
func ParseBool(s string) bool {
switch strings.ToLower(s) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
func pluralize(n int64) string {
if n > 1 {
return "s"
}
return ""
}

Some files were not shown because too many files have changed in this diff Show More