mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-18 06:29:42 +02:00
Dev
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"crypto/x509"
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/providers/dns/clouddns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/duckdns"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
)
|
||||
|
||||
@@ -54,7 +56,7 @@ type AutoCertProvider interface {
|
||||
GetExpiries() CertExpiries
|
||||
LoadCert() bool
|
||||
ObtainCert() NestedErrorLike
|
||||
RenewalOn() time.Time
|
||||
ShouldRenewOn() time.Time
|
||||
ScheduleRenewal()
|
||||
}
|
||||
|
||||
@@ -72,7 +74,7 @@ func (cfg AutoCertConfig) GetProvider() (AutoCertProvider, error) {
|
||||
}
|
||||
gen, ok := providersGenMap[cfg.Provider]
|
||||
if !ok {
|
||||
ne.Extraf("unknown provider: %s", cfg.Provider)
|
||||
ne.Extraf("unknown provider: %q", cfg.Provider)
|
||||
}
|
||||
if ne.HasExtras() {
|
||||
return nil, ne
|
||||
@@ -189,13 +191,9 @@ func (p *autoCertProvider) LoadCert() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *autoCertProvider) RenewalOn() time.Time {
|
||||
t := time.Now().AddDate(0, 0, 3)
|
||||
func (p *autoCertProvider) ShouldRenewOn() time.Time {
|
||||
for _, expiry := range p.certExpiries {
|
||||
if expiry.Before(t) {
|
||||
return time.Now()
|
||||
}
|
||||
return t
|
||||
return expiry.AddDate(0, -1, 0)
|
||||
}
|
||||
// this line should never be reached
|
||||
panic("no certificate available")
|
||||
@@ -203,8 +201,8 @@ func (p *autoCertProvider) RenewalOn() time.Time {
|
||||
|
||||
func (p *autoCertProvider) ScheduleRenewal() {
|
||||
for {
|
||||
t := time.Until(p.RenewalOn())
|
||||
aclog.Infof("next renewal in %v", t)
|
||||
t := time.Until(p.ShouldRenewOn())
|
||||
aclog.Infof("next renewal in %v", t.Round(time.Second))
|
||||
time.Sleep(t)
|
||||
err := p.renewIfNeeded()
|
||||
if err != nil {
|
||||
@@ -230,7 +228,29 @@ func (p *autoCertProvider) saveCert(cert *certificate.Resource) NestedErrorLike
|
||||
}
|
||||
|
||||
func (p *autoCertProvider) needRenewal() bool {
|
||||
return time.Now().After(p.RenewalOn())
|
||||
expired := time.Now().After(p.ShouldRenewOn())
|
||||
if expired {
|
||||
return true
|
||||
}
|
||||
if len(p.cfg.Domains) != len(p.certExpiries) {
|
||||
return true
|
||||
}
|
||||
wantedDomains := make([]string, len(p.cfg.Domains))
|
||||
certDomains := make([]string, len(p.certExpiries))
|
||||
copy(wantedDomains, p.cfg.Domains)
|
||||
i := 0
|
||||
for domain := range p.certExpiries {
|
||||
certDomains[i] = domain
|
||||
i++
|
||||
}
|
||||
slices.Sort(wantedDomains)
|
||||
slices.Sort(certDomains)
|
||||
for i, domain := range certDomains {
|
||||
if domain != wantedDomains[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *autoCertProvider) renewIfNeeded() NestedErrorLike {
|
||||
@@ -249,6 +269,7 @@ func (p *autoCertProvider) renewIfNeeded() NestedErrorLike {
|
||||
for {
|
||||
err := p.ObtainCert()
|
||||
if err == nil {
|
||||
aclog.Info("renewed certificate")
|
||||
return nil
|
||||
}
|
||||
trials++
|
||||
@@ -305,5 +326,6 @@ func setOptions[T interface{}](cfg *T, opt ProviderOptions) error {
|
||||
|
||||
var providersGenMap = map[string]ProviderGenerator{
|
||||
"cloudflare": providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
|
||||
"clouddns": providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
|
||||
"clouddns": providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
|
||||
"duckdns": providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -25,12 +25,10 @@ type Config interface {
|
||||
func NewConfig(path string) Config {
|
||||
cfg := &config{
|
||||
reader: &FileReader{Path: path},
|
||||
l: cfgl,
|
||||
}
|
||||
cfg.watcher = NewFileWatcher(
|
||||
path,
|
||||
cfg.MustReload, // OnChange
|
||||
func() { os.Exit(1) }, // OnDelete
|
||||
)
|
||||
// must init fields above before creating watcher
|
||||
cfg.watcher = cfg.NewFileWatcher()
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -43,10 +41,7 @@ func (cfg *config) Value() configModel {
|
||||
return *cfg.m
|
||||
}
|
||||
|
||||
func (cfg *config) Load(reader ...Reader) error {
|
||||
cfg.mutex.Lock()
|
||||
defer cfg.mutex.Unlock()
|
||||
|
||||
func (cfg *config) Load() error {
|
||||
if cfg.reader == nil {
|
||||
panic("config reader not set")
|
||||
}
|
||||
@@ -68,7 +63,7 @@ func (cfg *config) Load(reader ...Reader) error {
|
||||
ne.With(err)
|
||||
}
|
||||
|
||||
pErrs := NewNestedError("errors in these providers")
|
||||
pErrs := NewNestedError("these providers have errors")
|
||||
|
||||
for name, p := range model.Providers {
|
||||
if p.Kind != ProviderKind_File {
|
||||
@@ -90,13 +85,16 @@ func (cfg *config) Load(reader ...Reader) error {
|
||||
return ne
|
||||
}
|
||||
|
||||
cfg.mutex.Lock()
|
||||
defer cfg.mutex.Unlock()
|
||||
|
||||
cfg.m = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *config) MustLoad() {
|
||||
if err := cfg.Load(); err != nil {
|
||||
cfgl.Fatal(err)
|
||||
cfg.l.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +113,7 @@ func (cfg *config) Reload() error {
|
||||
|
||||
func (cfg *config) MustReload() {
|
||||
if err := cfg.Reload(); err != nil {
|
||||
cfgl.Fatal(err)
|
||||
cfg.l.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +142,7 @@ func (cfg *config) StartProviders() {
|
||||
cfg.providerInitialized = true
|
||||
|
||||
if pErrs.HasExtras() {
|
||||
cfgl.Error(pErrs)
|
||||
cfg.l.Error(pErrs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,6 +192,7 @@ func defaultConfig() *configModel {
|
||||
type config struct {
|
||||
m *configModel
|
||||
|
||||
l logrus.FieldLogger
|
||||
reader Reader
|
||||
watcher Watcher
|
||||
mutex sync.Mutex
|
||||
|
||||
@@ -123,7 +123,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
const wildcardLabelPrefix = "proxy.*."
|
||||
const wildcardAlias = "*"
|
||||
|
||||
const clientUrlFromEnv = "FROM_ENV"
|
||||
|
||||
@@ -147,18 +147,6 @@ const (
|
||||
var (
|
||||
configSchema *jsonschema.Schema
|
||||
providersSchema *jsonschema.Schema
|
||||
_ = func() *jsonschema.Compiler {
|
||||
c := jsonschema.NewCompiler()
|
||||
c.Draft = jsonschema.Draft7
|
||||
var err error
|
||||
if configSchema, err = c.Compile(configSchemaPath); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if providersSchema, err = c.Compile(providersSchemaPath); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c
|
||||
}()
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -168,17 +156,36 @@ const (
|
||||
|
||||
const udpBufferSize = 1500
|
||||
|
||||
var isHostNetworkMode = os.Getenv("GOPROXY_HOST_NETWORK") == "1"
|
||||
var isHostNetworkMode = getEnvBool("GOPROXY_HOST_NETWORK")
|
||||
|
||||
var logLevel = func() logrus.Level {
|
||||
switch os.Getenv("GOPROXY_DEBUG") {
|
||||
case "1", "true":
|
||||
if getEnvBool("GOPROXY_DEBUG") {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
return logrus.GetLevel()
|
||||
}()
|
||||
|
||||
var isRunningAsService = func() bool {
|
||||
v := os.Getenv("IS_SYSTEMD")
|
||||
return v == "1"
|
||||
}()
|
||||
var isRunningAsService = getEnvBool("IS_SYSTEMD") || getEnvBool("GOPROXY_IS_SYSTEMD") // IS_SYSTEMD is deprecated
|
||||
|
||||
var noSchemaValidation = getEnvBool("GOPROXY_NO_SCHEMA_VALIDATION")
|
||||
|
||||
func getEnvBool(key string) bool {
|
||||
v := os.Getenv(key)
|
||||
return v == "1" || v == "true"
|
||||
}
|
||||
|
||||
func initSchema() {
|
||||
if noSchemaValidation {
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonschema.NewCompiler()
|
||||
c.Draft = jsonschema.Draft7
|
||||
var err error
|
||||
if configSchema, err = c.Compile(configSchemaPath); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if providersSchema, err = c.Compile(providersSchemaPath); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -14,20 +15,15 @@ import (
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func (p *Provider) setConfigField(c *ProxyConfig, label string, value string, prefix string) error {
|
||||
if strings.HasPrefix(label, prefix) {
|
||||
field := strings.TrimPrefix(label, prefix)
|
||||
if err := setFieldFromSnake(c, field, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func setConfigField(pl *ProxyLabel, c *ProxyConfig) error {
|
||||
return setFieldFromSnake(c, pl.Field, pl.Value)
|
||||
}
|
||||
|
||||
func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP string) ProxyConfigSlice {
|
||||
func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP string) (ProxyConfigSlice, error) {
|
||||
var aliases []string
|
||||
|
||||
cfgs := make(ProxyConfigSlice, 0)
|
||||
cfgMap := make(map[string]*ProxyConfig)
|
||||
|
||||
containerName := strings.TrimPrefix(container.Names[0], "/")
|
||||
aliasesLabel, ok := container.Labels["proxy.aliases"]
|
||||
@@ -35,7 +31,8 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
if !ok {
|
||||
aliases = []string{containerName}
|
||||
} else {
|
||||
aliases = strings.Split(aliasesLabel, ",")
|
||||
v, _ := commaSepParser(aliasesLabel)
|
||||
aliases = v.([]string)
|
||||
}
|
||||
|
||||
if clientIP == "" && isHostNetworkMode {
|
||||
@@ -44,21 +41,42 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
isRemote := clientIP != ""
|
||||
|
||||
for _, alias := range aliases {
|
||||
ne := NewNestedError("invalid label config").Subjectf("container %s", containerName)
|
||||
cfgMap[alias] = &ProxyConfig{}
|
||||
}
|
||||
|
||||
l := p.l.WithField("container", containerName).WithField("alias", alias)
|
||||
config := NewProxyConfig(p)
|
||||
prefix := fmt.Sprintf("proxy.%s.", alias)
|
||||
for label, value := range container.Labels {
|
||||
err := p.setConfigField(&config, label, value, prefix)
|
||||
if err != nil {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subjectf("alias %s", alias))
|
||||
}
|
||||
err = p.setConfigField(&config, label, value, wildcardLabelPrefix)
|
||||
if err != nil {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subjectf("alias %s", alias))
|
||||
ne := NewNestedError("these labels have errors").Subject(containerName)
|
||||
|
||||
for label, value := range container.Labels {
|
||||
pl, err := parseProxyLabel(label, value)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errNotProxyLabel) {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subject(label))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if pl.Alias == wildcardAlias {
|
||||
for alias := range cfgMap {
|
||||
pl.Alias = alias
|
||||
err = setConfigField(pl, cfgMap[alias])
|
||||
if err != nil {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subject(pl.Alias))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
config, ok := cfgMap[pl.Alias]
|
||||
if !ok {
|
||||
ne.ExtraError(NewNestedError("unknown alias").Subject(pl.Alias))
|
||||
continue
|
||||
}
|
||||
err = setConfigField(pl, config)
|
||||
if err != nil {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subject(pl.Alias))
|
||||
}
|
||||
}
|
||||
|
||||
for alias, config := range cfgMap {
|
||||
l := p.l.WithField("alias", alias)
|
||||
if config.Port == "" {
|
||||
config.Port = fmt.Sprintf("%d", selectPort(container, isRemote))
|
||||
}
|
||||
@@ -70,8 +88,6 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
switch {
|
||||
case strings.HasSuffix(config.Port, "443"):
|
||||
config.Scheme = "https"
|
||||
case strings.HasPrefix(container.Image, "sha256:"):
|
||||
config.Scheme = "http"
|
||||
default:
|
||||
imageName := getImageName(container)
|
||||
_, isKnownImage := ImageNamePortMapTCP[imageName]
|
||||
@@ -90,7 +106,7 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
var err error
|
||||
// find matching port
|
||||
srcPort := config.Port[1:]
|
||||
config.Port, err = findMatchingContainerPort(container,srcPort)
|
||||
config.Port, err = findMatchingContainerPort(container, srcPort)
|
||||
if err != nil {
|
||||
ne.ExtraError(NewNestedErrorFrom(err).Subjectf("alias %s", alias))
|
||||
}
|
||||
@@ -98,8 +114,7 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
config.Port = fmt.Sprintf("%s:%s", srcPort, config.Port)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
if config.Host == "" {
|
||||
switch {
|
||||
case isRemote:
|
||||
@@ -126,12 +141,15 @@ func (p *Provider) getContainerProxyConfigs(container *types.Container, clientIP
|
||||
config.Alias = alias
|
||||
|
||||
if ne.HasExtras() {
|
||||
l.Error(ne)
|
||||
continue
|
||||
}
|
||||
cfgs = append(cfgs, config)
|
||||
cfgs = append(cfgs, *config)
|
||||
}
|
||||
return cfgs
|
||||
|
||||
if ne.HasExtras() {
|
||||
return nil, ne
|
||||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getDockerClient() (*client.Client, error) {
|
||||
@@ -196,8 +214,19 @@ func (p *Provider) getDockerProxyConfigs() (ProxyConfigSlice, error) {
|
||||
|
||||
cfgs := make(ProxyConfigSlice, 0)
|
||||
|
||||
ne := NewNestedError("these containers have errors")
|
||||
for _, container := range containerSlice {
|
||||
cfgs = append(cfgs, p.getContainerProxyConfigs(&container, clientIP)...)
|
||||
ccfgs, err := p.getContainerProxyConfigs(&container, clientIP)
|
||||
if err != nil {
|
||||
ne.ExtraError(err)
|
||||
continue
|
||||
}
|
||||
cfgs = append(cfgs, ccfgs...)
|
||||
}
|
||||
|
||||
if ne.HasExtras() {
|
||||
// print but ignore
|
||||
p.l.Error(ne)
|
||||
}
|
||||
|
||||
return cfgs, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -46,6 +47,10 @@ func NewNestedErrorFrom(err error) NestedErrorLike {
|
||||
if err == nil {
|
||||
panic("cannot convert nil error to NestedError")
|
||||
}
|
||||
errUnwrap := errors.Unwrap(err)
|
||||
if errUnwrap != nil {
|
||||
return NewNestedErrorFrom(errUnwrap)
|
||||
}
|
||||
return NewNestedError(err.Error())
|
||||
}
|
||||
|
||||
@@ -92,23 +97,23 @@ func (ne *NestedError) Level() int {
|
||||
return ne.level
|
||||
}
|
||||
|
||||
func (ef *NestedError) Error() string {
|
||||
func (ne *NestedError) Error() string {
|
||||
var buf strings.Builder
|
||||
ef.writeToSB(&buf, "")
|
||||
ne.writeToSB(&buf, ne.level, "")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (ef *NestedError) HasInner() bool {
|
||||
return ef.inner != nil
|
||||
func (ne *NestedError) HasInner() bool {
|
||||
return ne.inner != nil
|
||||
}
|
||||
|
||||
func (ef *NestedError) HasExtras() bool {
|
||||
return len(ef.extras) > 0
|
||||
func (ne *NestedError) HasExtras() bool {
|
||||
return len(ne.extras) > 0
|
||||
}
|
||||
|
||||
func (ef *NestedError) With(inner error) NestedErrorLike {
|
||||
ef.Lock()
|
||||
defer ef.Unlock()
|
||||
func (ne *NestedError) With(inner error) NestedErrorLike {
|
||||
ne.Lock()
|
||||
defer ne.Unlock()
|
||||
|
||||
var in *NestedError
|
||||
|
||||
@@ -116,79 +121,75 @@ func (ef *NestedError) With(inner error) NestedErrorLike {
|
||||
case NestedErrorLike:
|
||||
in = t.copy()
|
||||
default:
|
||||
in = &NestedError{extras: []string{t.Error()}}
|
||||
in = &NestedError{message: t.Error()}
|
||||
}
|
||||
if ef.inner == nil {
|
||||
ef.inner = in
|
||||
if ne.inner == nil {
|
||||
ne.inner = in
|
||||
} else {
|
||||
ef.inner.ExtraError(in)
|
||||
ne.inner.ExtraError(in)
|
||||
}
|
||||
root := ef
|
||||
root := ne
|
||||
for root.inner != nil {
|
||||
root.inner.level = root.level + 1
|
||||
root = root.inner
|
||||
}
|
||||
return ef
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ef *NestedError) addLevel(level int) NestedErrorLike {
|
||||
ef.level += level
|
||||
if ef.inner != nil {
|
||||
ef.inner.addLevel(level)
|
||||
func (ne *NestedError) addLevel(level int) NestedErrorLike {
|
||||
ne.level += level
|
||||
if ne.inner != nil {
|
||||
ne.inner.addLevel(level)
|
||||
}
|
||||
return ef
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ef *NestedError) copy() *NestedError {
|
||||
func (ne *NestedError) copy() *NestedError {
|
||||
var inner *NestedError
|
||||
if ef.inner != nil {
|
||||
inner = ef.inner.copy()
|
||||
if ne.inner != nil {
|
||||
inner = ne.inner.copy()
|
||||
}
|
||||
return &NestedError{
|
||||
subject: ef.subject,
|
||||
message: ef.message,
|
||||
extras: ef.extras,
|
||||
subject: ne.subject,
|
||||
message: ne.message,
|
||||
extras: ne.extras,
|
||||
inner: inner,
|
||||
level: ef.level,
|
||||
}
|
||||
}
|
||||
|
||||
func (ef *NestedError) writeIndents(sb *strings.Builder, level int) {
|
||||
func (ne *NestedError) writeIndents(sb *strings.Builder, level int) {
|
||||
for i := 0; i < level; i++ {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
}
|
||||
|
||||
func (ef *NestedError) writeToSB(sb *strings.Builder, prefix string) {
|
||||
ef.writeIndents(sb, ef.level)
|
||||
func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
|
||||
ne.writeIndents(sb, level)
|
||||
sb.WriteString(prefix)
|
||||
|
||||
if ef.subject != "" {
|
||||
sb.WriteRune('"')
|
||||
sb.WriteString(ef.subject)
|
||||
sb.WriteRune('"')
|
||||
if ef.message != "" {
|
||||
sb.WriteString(":\n")
|
||||
} else {
|
||||
sb.WriteRune('\n')
|
||||
if ne.subject != "" {
|
||||
sb.WriteString(ne.subject)
|
||||
if ne.message != "" {
|
||||
sb.WriteString(": ")
|
||||
}
|
||||
}
|
||||
if ef.message != "" {
|
||||
ef.writeIndents(sb, ef.level)
|
||||
sb.WriteString(ef.message)
|
||||
sb.WriteRune('\n')
|
||||
if ne.message != "" {
|
||||
sb.WriteString(ne.message)
|
||||
}
|
||||
for _, l := range ef.extras {
|
||||
l = strings.TrimSpace(l)
|
||||
if ne.HasExtras() || ne.HasInner() {
|
||||
sb.WriteString(":\n")
|
||||
}
|
||||
level += 1
|
||||
for _, l := range ne.extras {
|
||||
if l == "" {
|
||||
continue
|
||||
}
|
||||
ef.writeIndents(sb, ef.level)
|
||||
ne.writeIndents(sb, level)
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(l)
|
||||
sb.WriteRune('\n')
|
||||
}
|
||||
if ef.inner != nil {
|
||||
ef.inner.writeToSB(sb, "- ")
|
||||
if ne.inner != nil {
|
||||
ne.inner.writeToSB(sb, level, "- ")
|
||||
}
|
||||
}
|
||||
|
||||
66
src/go-proxy/error_test.go
Normal file
66
src/go-proxy/error_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func AssertEq(t *testing.T, got, want string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorSimple(t *testing.T) {
|
||||
ne := NewNestedError("foo bar")
|
||||
AssertEq(t, ne.Error(), "foo bar")
|
||||
ne.Subject("baz")
|
||||
AssertEq(t, ne.Error(), "baz: foo bar")
|
||||
}
|
||||
|
||||
func TestErrorSubjectOnly(t *testing.T) {
|
||||
ne := NewNestedError("").Subject("bar")
|
||||
AssertEq(t, ne.Error(), "bar")
|
||||
}
|
||||
|
||||
func TestErrorExtra(t *testing.T) {
|
||||
ne := NewNestedError("foo").Extra("bar").Extra("baz")
|
||||
AssertEq(t, ne.Error(), "foo:\n - bar\n - baz\n")
|
||||
}
|
||||
|
||||
func TestErrorNested(t *testing.T) {
|
||||
inner := NewNestedError("inner").
|
||||
Extra("123").
|
||||
Extra("456")
|
||||
inner2 := NewNestedError("inner").
|
||||
Subject("2").
|
||||
Extra("456").
|
||||
Extra("789")
|
||||
inner3 := NewNestedError("inner").
|
||||
Subject("3").
|
||||
Extra("456").
|
||||
Extra("789")
|
||||
ne := NewNestedError("foo").
|
||||
Extra("bar").
|
||||
Extra("baz").
|
||||
ExtraError(inner).
|
||||
With(inner.With(inner2.With(inner3)))
|
||||
want :=
|
||||
`foo:
|
||||
- bar
|
||||
- baz
|
||||
- inner:
|
||||
- 123
|
||||
- 456
|
||||
- inner:
|
||||
- 123
|
||||
- 456
|
||||
- 2: inner:
|
||||
- 456
|
||||
- 789
|
||||
- 3: inner:
|
||||
- 456
|
||||
- 789
|
||||
`
|
||||
AssertEq(t, ne.Error(), want)
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func NewHTTPRoute(config *ProxyConfig) (*HTTPRoute, error) {
|
||||
tr = transport
|
||||
}
|
||||
|
||||
proxy := NewSingleHostReverseProxy(url, tr)
|
||||
proxy := NewReverseProxy(url, tr, config)
|
||||
|
||||
route := &HTTPRoute{
|
||||
Alias: config.Alias,
|
||||
@@ -42,36 +42,35 @@ func NewHTTPRoute(config *ProxyConfig) (*HTTPRoute, error) {
|
||||
Path: config.Path,
|
||||
Proxy: proxy,
|
||||
PathMode: config.PathMode,
|
||||
l: hrlog.WithFields(logrus.Fields{
|
||||
"alias": config.Alias,
|
||||
// "path": config.Path,
|
||||
// "path_mode": config.PathMode,
|
||||
}),
|
||||
l: logrus.WithField("alias", config.Alias),
|
||||
}
|
||||
|
||||
var rewriteBegin = proxy.Rewrite
|
||||
var rewrite func(*ProxyRequest)
|
||||
var modifyResponse func(*http.Response) error
|
||||
|
||||
switch {
|
||||
case config.Path == "", config.PathMode == ProxyPathMode_Forward:
|
||||
// no path or forward path
|
||||
if config.Path == "" || config.PathMode == ProxyPathMode_Forward {
|
||||
rewrite = rewriteBegin
|
||||
case config.PathMode == ProxyPathMode_RemovedPath:
|
||||
rewrite = func(pr *ProxyRequest) {
|
||||
rewriteBegin(pr)
|
||||
pr.Out.URL.Path = strings.TrimPrefix(pr.Out.URL.Path, config.Path)
|
||||
} else {
|
||||
switch config.PathMode {
|
||||
case ProxyPathMode_RemovedPath:
|
||||
rewrite = func(pr *ProxyRequest) {
|
||||
rewriteBegin(pr)
|
||||
pr.Out.URL.Path = strings.TrimPrefix(pr.Out.URL.Path, config.Path)
|
||||
}
|
||||
case ProxyPathMode_Sub:
|
||||
rewrite = func(pr *ProxyRequest) {
|
||||
rewriteBegin(pr)
|
||||
// disable compression
|
||||
pr.Out.Header.Set("Accept-Encoding", "identity")
|
||||
// remove path prefix
|
||||
pr.Out.URL.Path = strings.TrimPrefix(pr.Out.URL.Path, config.Path)
|
||||
}
|
||||
modifyResponse = config.pathSubModResp
|
||||
default:
|
||||
return nil, NewNestedError("invalid path mode").Subject(config.PathMode)
|
||||
}
|
||||
case config.PathMode == ProxyPathMode_Sub:
|
||||
rewrite = func(pr *ProxyRequest) {
|
||||
rewriteBegin(pr)
|
||||
// disable compression
|
||||
pr.Out.Header.Set("Accept-Encoding", "identity")
|
||||
// remove path prefix
|
||||
pr.Out.URL.Path = strings.TrimPrefix(pr.Out.URL.Path, config.Path)
|
||||
}
|
||||
modifyResponse = config.pathSubModResp
|
||||
default:
|
||||
return nil, NewNestedError("invalid path mode").Subject(config.PathMode)
|
||||
}
|
||||
|
||||
if logLevel == logrus.DebugLevel {
|
||||
@@ -96,8 +95,9 @@ func NewHTTPRoute(config *ProxyConfig) (*HTTPRoute, error) {
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Start() {
|
||||
// dummy
|
||||
httpRoutes.Get(r.Alias).Add(r.Path, r)
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Stop() {
|
||||
httpRoutes.Delete(r.Alias)
|
||||
}
|
||||
|
||||
@@ -2,10 +2,9 @@ package main
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
|
||||
var palog = logrus.WithField("component", "panel")
|
||||
var prlog = logrus.WithField("component", "provider")
|
||||
var cfgl = logrus.WithField("component", "config")
|
||||
var hrlog = logrus.WithField("component", "http_proxy")
|
||||
var srlog = logrus.WithField("component", "stream")
|
||||
var wlog = logrus.WithField("component", "watcher")
|
||||
var aclog = logrus.WithField("component", "autocert")
|
||||
var palog = logrus.WithField("?", "panel")
|
||||
var cfgl = logrus.WithField("?", "config")
|
||||
var hrlog = logrus.WithField("?", "http")
|
||||
var srlog = logrus.WithField("?", "stream")
|
||||
var wlog = logrus.WithField("?", "watcher")
|
||||
var aclog = logrus.WithField("?", "autocert")
|
||||
@@ -43,6 +43,8 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
initSchema()
|
||||
|
||||
cfg = NewConfig(configPath)
|
||||
cfg.MustLoad()
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -10,15 +8,17 @@ type Provider struct {
|
||||
Kind string `json:"kind"` // docker, file
|
||||
Value string `json:"value"`
|
||||
|
||||
watcher Watcher
|
||||
routes map[string]Route // id -> Route
|
||||
mutex sync.Mutex
|
||||
l logrus.FieldLogger
|
||||
watcher Watcher
|
||||
routes map[string]Route // id -> Route
|
||||
l logrus.FieldLogger
|
||||
reloadReqCh chan struct{}
|
||||
}
|
||||
|
||||
// Init is called after LoadProxyConfig
|
||||
func (p *Provider) Init(name string) error {
|
||||
p.l = prlog.WithFields(logrus.Fields{"kind": p.Kind, "name": name})
|
||||
p.l = logrus.WithField("provider", name)
|
||||
p.reloadReqCh = make(chan struct{}, 1)
|
||||
|
||||
defer p.initWatcher()
|
||||
|
||||
if err := p.loadProxyConfig(); err != nil {
|
||||
@@ -40,16 +40,23 @@ func (p *Provider) StopAllRoutes() {
|
||||
}
|
||||
|
||||
func (p *Provider) ReloadRoutes() {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
select {
|
||||
case p.reloadReqCh <- struct{}{}:
|
||||
defer func() {
|
||||
<-p.reloadReqCh
|
||||
}()
|
||||
|
||||
p.StopAllRoutes()
|
||||
err := p.loadProxyConfig()
|
||||
if err != nil {
|
||||
p.l.Error("failed to reload routes: ", err)
|
||||
p.StopAllRoutes()
|
||||
err := p.loadProxyConfig()
|
||||
if err != nil {
|
||||
p.l.Error("failed to reload routes: ", err)
|
||||
return
|
||||
}
|
||||
p.StartAllRoutes()
|
||||
default:
|
||||
p.l.Info("reload request already in progress")
|
||||
return
|
||||
}
|
||||
p.StartAllRoutes()
|
||||
}
|
||||
|
||||
func (p *Provider) loadProxyConfig() error {
|
||||
@@ -97,9 +104,9 @@ func (p *Provider) initWatcher() error {
|
||||
if err != nil {
|
||||
return NewNestedError("unable to create docker client").With(err)
|
||||
}
|
||||
p.watcher = NewDockerWatcher(dockerClient, p.ReloadRoutes)
|
||||
p.watcher = p.NewDockerWatcher(dockerClient)
|
||||
case ProviderKind_File:
|
||||
p.watcher = NewFileWatcher(p.GetFilePath(), p.ReloadRoutes, p.StopAllRoutes)
|
||||
p.watcher = p.NewFileWatcher()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,29 +1,26 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Alias string `yaml:"-" json:"-"`
|
||||
Scheme string `yaml:"scheme" json:"scheme"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port string `yaml:"port" json:"port"`
|
||||
LoadBalance string `yaml:"-" json:"-"` // docker provider only
|
||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // http proxy only
|
||||
Path string `yaml:"path" json:"path"` // http proxy only
|
||||
PathMode string `yaml:"path_mode" json:"path_mode"` // http proxy only
|
||||
|
||||
provider *Provider
|
||||
Alias string `yaml:"-" json:"-"`
|
||||
Scheme string `yaml:"scheme" json:"scheme"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port string `yaml:"port" json:"port"`
|
||||
LoadBalance string `yaml:"-" json:"-"` // docker provider only
|
||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // http proxy only
|
||||
Path string `yaml:"path" json:"path"` // http proxy only
|
||||
PathMode string `yaml:"path_mode" json:"path_mode"` // http proxy only
|
||||
SetHeaders http.Header `yaml:"set_headers" json:"set_headers"` // http proxy only
|
||||
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http proxy only
|
||||
}
|
||||
|
||||
type ProxyConfigMap map[string]ProxyConfig
|
||||
type ProxyConfigSlice []ProxyConfig
|
||||
|
||||
func NewProxyConfig(provider *Provider) ProxyConfig {
|
||||
return ProxyConfig{
|
||||
provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// used by `GetFileProxyConfigs`
|
||||
func (cfg *ProxyConfig) SetDefaults() error {
|
||||
err := NewNestedError("invalid proxy config").Subject(cfg.Alias)
|
||||
@@ -55,4 +52,4 @@ func (cfg *ProxyConfig) SetDefaults() error {
|
||||
|
||||
func (cfg *ProxyConfig) GetID() string {
|
||||
return fmt.Sprintf("%s-%s-%s-%s-%s", cfg.Alias, cfg.Scheme, cfg.Host, cfg.Port, cfg.Path)
|
||||
}
|
||||
}
|
||||
92
src/go-proxy/proxy_label.go
Normal file
92
src/go-proxy/proxy_label.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ProxyLabel struct {
|
||||
Alias string
|
||||
Field string
|
||||
Value any
|
||||
}
|
||||
|
||||
var errNotProxyLabel = errors.New("not a proxy label")
|
||||
var errInvalidSetHeaderLine = errors.New("invalid set header line")
|
||||
var errInvalidBoolean = errors.New("invalid boolean")
|
||||
|
||||
const proxyLabelNamespace = "proxy"
|
||||
|
||||
func parseProxyLabel(label string, value string) (*ProxyLabel, error) {
|
||||
ns := strings.Split(label, ".")
|
||||
var v any = value
|
||||
|
||||
if len(ns) != 3 {
|
||||
return nil, errNotProxyLabel
|
||||
}
|
||||
|
||||
if ns[0] != proxyLabelNamespace {
|
||||
return nil, errNotProxyLabel
|
||||
}
|
||||
|
||||
field := ns[2]
|
||||
|
||||
var err error
|
||||
parser, ok := valueParser[field]
|
||||
|
||||
if ok {
|
||||
v, err = parser(v.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ProxyLabel{
|
||||
Alias: ns[1],
|
||||
Field: field,
|
||||
Value: v,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setHeadersParser(value string) (any, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
lines := strings.Split(value, "\n")
|
||||
h := make(http.Header)
|
||||
for _, line := range lines {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("%w: %q", errInvalidSetHeaderLine, line)
|
||||
}
|
||||
key := strings.TrimSpace(parts[0])
|
||||
val := strings.TrimSpace(parts[1])
|
||||
h.Add(key, val)
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func commaSepParser(value string) (any, error) {
|
||||
v := strings.Split(value, ",")
|
||||
for i := range v {
|
||||
v[i] = strings.TrimSpace(v[i])
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func boolParser(value string) (any, error) {
|
||||
switch strings.ToLower(value) {
|
||||
case "true", "yes", "1":
|
||||
return true, nil
|
||||
case "false", "no", "0":
|
||||
return false, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %q", errInvalidBoolean, value)
|
||||
}
|
||||
}
|
||||
|
||||
var valueParser = map[string]func(string) (any, error){
|
||||
"set_headers": setHeadersParser,
|
||||
"hide_headers": commaSepParser,
|
||||
"no_tls_verify": boolParser,
|
||||
}
|
||||
186
src/go-proxy/proxy_label_test.go
Normal file
186
src/go-proxy/proxy_label_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeLabel(alias string, field string) string {
|
||||
return fmt.Sprintf("proxy.%s.%s", alias, field)
|
||||
}
|
||||
|
||||
func TestNotProxyLabel(t *testing.T) {
|
||||
pl, err := parseProxyLabel("foo.bar", "1234")
|
||||
if !errors.Is(err, errNotProxyLabel) {
|
||||
t.Errorf("expected err NotProxyLabel, got %v", err)
|
||||
}
|
||||
if pl != nil {
|
||||
t.Errorf("expected nil, got %v", pl)
|
||||
}
|
||||
_, err = parseProxyLabel("proxy.foo", "bar")
|
||||
if !errors.Is(err, errNotProxyLabel) {
|
||||
t.Errorf("expected err InvalidProxyLabel, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringProxyLabel(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "ip"
|
||||
v := "bar"
|
||||
pl, err := parseProxyLabel(makeLabel(alias, field), v)
|
||||
if err != nil {
|
||||
t.Errorf("expected err=nil, got %v", err)
|
||||
}
|
||||
if pl.Alias != alias {
|
||||
t.Errorf("expected alias=%s, got %s", alias, pl.Alias)
|
||||
}
|
||||
if pl.Field != field {
|
||||
t.Errorf("expected field=%s, got %s", field, pl.Field)
|
||||
}
|
||||
if pl.Value != v {
|
||||
t.Errorf("expected value=%q, got %s", v, pl.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolProxyLabelValid(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "no_tls_verify"
|
||||
tests := map[string]bool{
|
||||
"true": true,
|
||||
"TRUE": true,
|
||||
"yes": true,
|
||||
"1": true,
|
||||
"false": false,
|
||||
"FALSE": false,
|
||||
"no": false,
|
||||
"0": false,
|
||||
}
|
||||
|
||||
for k, v := range tests {
|
||||
pl, err := parseProxyLabel(makeLabel(alias, field), k)
|
||||
if err != nil {
|
||||
t.Errorf("expected err=nil, got %v", err)
|
||||
}
|
||||
if pl.Alias != alias {
|
||||
t.Errorf("expected alias=%s, got %s", alias, pl.Alias)
|
||||
}
|
||||
if pl.Field != field {
|
||||
t.Errorf("expected field=%s, got %s", field, pl.Field)
|
||||
}
|
||||
if pl.Value != v {
|
||||
t.Errorf("expected value=%v, got %v", v, pl.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolProxyLabelInvalid(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "no_tls_verify"
|
||||
_, err := parseProxyLabel(makeLabel(alias, field), "invalid")
|
||||
if !errors.Is(err, errInvalidBoolean) {
|
||||
t.Errorf("expected err InvalidProxyLabel, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderProxyLabelValid(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "set_headers"
|
||||
v := `
|
||||
X-Custom-Header1: foo
|
||||
X-Custom-Header1: bar
|
||||
X-Custom-Header2: baz
|
||||
`
|
||||
h := make(http.Header, 0)
|
||||
h.Set("X-Custom-Header1", "foo")
|
||||
h.Add("X-Custom-Header1", "bar")
|
||||
h.Set("X-Custom-Header2", "baz")
|
||||
|
||||
pl, err := parseProxyLabel(makeLabel(alias, field), v)
|
||||
if err != nil {
|
||||
t.Errorf("expected err=nil, got %v", err)
|
||||
}
|
||||
if pl.Alias != alias {
|
||||
t.Errorf("expected alias=%s, got %s", alias, pl.Alias)
|
||||
}
|
||||
if pl.Field != field {
|
||||
t.Errorf("expected field=%s, got %s", field, pl.Field)
|
||||
}
|
||||
hGot, ok := pl.Value.(http.Header)
|
||||
if !ok {
|
||||
t.Error("value is not http.Header")
|
||||
return
|
||||
}
|
||||
for k, vWant := range h {
|
||||
vGot := hGot[k]
|
||||
if !reflect.DeepEqual(vGot, vWant) {
|
||||
t.Errorf("expected %s=%q, got %q", k, vWant, vGot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderProxyLabelInvalid(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "set_headers"
|
||||
tests := []string{
|
||||
"X-Custom-Header1 = bar",
|
||||
"X-Custom-Header1",
|
||||
}
|
||||
|
||||
for _, v := range tests {
|
||||
_, err := parseProxyLabel(makeLabel(alias, field), v)
|
||||
if !errors.Is(err, errInvalidSetHeaderLine) {
|
||||
t.Errorf("expected err InvalidProxyLabel for %q, got %v", v, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommaSepProxyLabelSingle(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "hide_headers"
|
||||
v := "X-Custom-Header1"
|
||||
pl, err := parseProxyLabel(makeLabel(alias, field), v)
|
||||
if err != nil {
|
||||
t.Errorf("expected err=nil, got %v", err)
|
||||
}
|
||||
if pl.Alias != alias {
|
||||
t.Errorf("expected alias=%s, got %s", alias, pl.Alias)
|
||||
}
|
||||
if pl.Field != field {
|
||||
t.Errorf("expected field=%s, got %s", field, pl.Field)
|
||||
}
|
||||
sGot, ok := pl.Value.([]string)
|
||||
sWant := []string{"X-Custom-Header1"}
|
||||
if !ok {
|
||||
t.Error("value is not []string")
|
||||
}
|
||||
if !reflect.DeepEqual(sGot, sWant) {
|
||||
t.Errorf("expected %q, got %q", sWant, sGot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommaSepProxyLabelMulti(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "hide_headers"
|
||||
v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
|
||||
pl, err := parseProxyLabel(makeLabel(alias, field), v)
|
||||
if err != nil {
|
||||
t.Errorf("expected err=nil, got %v", err)
|
||||
}
|
||||
if pl.Alias != alias {
|
||||
t.Errorf("expected alias=%s, got %s", alias, pl.Alias)
|
||||
}
|
||||
if pl.Field != field {
|
||||
t.Errorf("expected field=%s, got %s", field, pl.Field)
|
||||
}
|
||||
sGot, ok := pl.Value.([]string)
|
||||
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
||||
if !ok {
|
||||
t.Error("value is not []string")
|
||||
}
|
||||
if !reflect.DeepEqual(sGot, sWant) {
|
||||
t.Errorf("expected %q, got %q", sWant, sGot)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package main
|
||||
|
||||
// A small mod on net/http/httputils
|
||||
// A small mod on net/http/httputil/reverseproxy.go
|
||||
// that doubled the performance
|
||||
|
||||
import (
|
||||
@@ -8,14 +8,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
@@ -39,16 +37,16 @@ type ProxyRequest struct {
|
||||
//
|
||||
// SetURL rewrites the outbound Host header to match the target's host.
|
||||
// To preserve the inbound request's Host header (the default behavior
|
||||
// of [NewSingleHostReverseProxy]):
|
||||
// of [NewReverseProxy]):
|
||||
//
|
||||
// rewriteFunc := func(r *httputil.ProxyRequest) {
|
||||
// r.SetURL(url)
|
||||
// r.Out.Host = r.In.Host
|
||||
// }
|
||||
func (r *ProxyRequest) SetURL(target *url.URL) {
|
||||
rewriteRequestURL(r.Out, target)
|
||||
r.Out.Host = ""
|
||||
}
|
||||
// func (r *ProxyRequest) SetURL(target *url.URL) {
|
||||
// rewriteRequestURL(r.Out, target)
|
||||
// r.Out.Host = ""
|
||||
// }
|
||||
|
||||
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
|
||||
// X-Forwarded-Proto headers of the outbound request.
|
||||
@@ -132,17 +130,17 @@ type ReverseProxy struct {
|
||||
// recognizes a response as a streaming response, or
|
||||
// if its ContentLength is -1; for such responses, writes
|
||||
// are flushed to the client immediately.
|
||||
FlushInterval time.Duration
|
||||
// FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging is done via the log package's standard logger.
|
||||
ErrorLog *log.Logger
|
||||
// ErrorLog *log.Logger
|
||||
|
||||
// BufferPool optionally specifies a buffer pool to
|
||||
// get byte slices for use by io.CopyBuffer when
|
||||
// copying HTTP response bodies.
|
||||
BufferPool BufferPool
|
||||
// BufferPool BufferPool
|
||||
|
||||
// ModifyResponse is an optional function that modifies the
|
||||
// Response from the backend. It is called if the backend
|
||||
@@ -203,18 +201,18 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
return a.Path + b.Path, apath + bpath
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new [ReverseProxy] that routes
|
||||
// NewReverseProxy returns a new [ReverseProxy] that routes
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
//
|
||||
// NewSingleHostReverseProxy does not rewrite the Host header.
|
||||
// NewReverseProxy does not rewrite the Host header.
|
||||
//
|
||||
// To customize the ReverseProxy behavior beyond what
|
||||
// NewSingleHostReverseProxy provides, use ReverseProxy directly
|
||||
// NewReverseProxy provides, use ReverseProxy directly
|
||||
// with a Rewrite function. The ProxyRequest SetURL method
|
||||
// may be used to route the outbound request. (Note that SetURL,
|
||||
// unlike NewSingleHostReverseProxy, rewrites the Host header
|
||||
// unlike NewReverseProxy, rewrites the Host header
|
||||
// of the outbound request by default.)
|
||||
//
|
||||
// proxy := &ReverseProxy{
|
||||
@@ -223,9 +221,34 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
// r.Out.Host = r.In.Host // if desired
|
||||
// },
|
||||
// }
|
||||
func NewSingleHostReverseProxy(target *url.URL, transport *http.Transport) *ReverseProxy {
|
||||
func NewReverseProxy(target *url.URL, transport *http.Transport, config *ProxyConfig) *ReverseProxy {
|
||||
// check on init rather than on request
|
||||
var setHeaders = func(r *http.Request) {}
|
||||
var hideHeaders = func(r *http.Request) {}
|
||||
if len(config.SetHeaders) > 0 {
|
||||
setHeaders = func(r *http.Request) {
|
||||
h := config.SetHeaders.Clone()
|
||||
for k, vv := range h {
|
||||
if k == "Host" {
|
||||
r.Host = vv[0]
|
||||
} else {
|
||||
r.Header[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(config.HideHeaders) > 0 {
|
||||
hideHeaders = func(r *http.Request) {
|
||||
for _, k := range config.HideHeaders {
|
||||
r.Header.Del(k)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &ReverseProxy{Rewrite: func(pr *ProxyRequest) {
|
||||
rewriteRequestURL(pr.Out, target)
|
||||
pr.SetXForwarded()
|
||||
setHeaders(pr.Out)
|
||||
hideHeaders(pr.Out)
|
||||
}, Transport: transport}
|
||||
}
|
||||
|
||||
@@ -380,7 +403,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Strip client-provided forwarding headers.
|
||||
// The Rewrite func may use SetXForwarded to set new values
|
||||
// for these or copy the previous values from the inbound request.
|
||||
// outreq.Header.Del("Forwarded")
|
||||
outreq.Header.Del("Forwarded")
|
||||
// outreq.Header.Del("X-Forwarded-For")
|
||||
// outreq.Header.Del("X-Forwarded-Host")
|
||||
// outreq.Header.Del("X-Forwarded-Proto")
|
||||
@@ -388,29 +411,27 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// NOTE: removed
|
||||
// Remove unparsable query parameters from the outbound request.
|
||||
// outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
|
||||
|
||||
pr := &ProxyRequest{
|
||||
In: req,
|
||||
Out: outreq,
|
||||
}
|
||||
pr.SetXForwarded() // NOTE: added
|
||||
p.Rewrite(pr)
|
||||
outreq = pr.Out
|
||||
// NOTE: removed
|
||||
// } else {
|
||||
// if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// // If we aren't the first proxy retain prior
|
||||
// // X-Forwarded-For information as a comma+space
|
||||
// // separated list and fold multiple headers into one.
|
||||
// prior, ok := outreq.Header["X-Forwarded-For"]
|
||||
// omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
|
||||
// if len(prior) > 0 {
|
||||
// clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
// }
|
||||
// if !omit {
|
||||
// outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
// }
|
||||
// if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// // If we aren't the first proxy retain prior
|
||||
// // X-Forwarded-For information as a comma+space
|
||||
// // separated list and fold multiple headers into one.
|
||||
// prior, ok := outreq.Header["X-Forwarded-For"]
|
||||
// omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
|
||||
// if len(prior) > 0 {
|
||||
// clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
// }
|
||||
// if !omit {
|
||||
// outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
@@ -637,11 +658,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// }
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...any) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
hrlog.Printf(format, args...)
|
||||
}
|
||||
// if p.ErrorLog != nil {
|
||||
// p.ErrorLog.Printf(format, args...)
|
||||
// } else {
|
||||
hrlog.Errorf(format, args...)
|
||||
// }
|
||||
}
|
||||
|
||||
// NOTE: removed
|
||||
102
src/go-proxy/reverse_proxy_mod_test.go
Normal file
102
src/go-proxy/reverse_proxy_mod_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var proxyCfg ProxyConfig
|
||||
var proxyUrl, _ = url.Parse("http://127.0.0.1:8181")
|
||||
var proxyServer = NewServer(ServerOptions{
|
||||
Name: "proxy",
|
||||
HTTPAddr: ":8080",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
NewReverseProxy(proxyUrl, &http.Transport{}, &proxyCfg).ServeHTTP(w, r)
|
||||
}),
|
||||
})
|
||||
|
||||
var testServer = NewServer(ServerOptions{
|
||||
Name: "test",
|
||||
HTTPAddr: ":8181",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := r.Header
|
||||
for k, vv := range h {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
var httpClient = http.DefaultClient
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
proxyServer.Start()
|
||||
testServer.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
code := m.Run()
|
||||
proxyServer.Stop()
|
||||
testServer.Stop()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSetHeader(t *testing.T) {
|
||||
hWant := http.Header{"X-Test": []string{"foo", "bar"}, "X-Test2": []string{"baz"}}
|
||||
proxyCfg = ProxyConfig{
|
||||
Alias: "test",
|
||||
Scheme: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: "8181",
|
||||
SetHeaders: hWant,
|
||||
}
|
||||
req, err := http.NewRequest("HEAD", "http://127.0.0.1:8080", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hGot := resp.Header
|
||||
t.Log("headers: ", hGot)
|
||||
for k, v := range hWant {
|
||||
if !reflect.DeepEqual(hGot[k], v) {
|
||||
t.Errorf("header %s: expected %v, got %v", k, v, hGot[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHideHeader(t *testing.T) {
|
||||
hHide := []string{"X-Test", "X-Test2"}
|
||||
proxyCfg = ProxyConfig{
|
||||
Alias: "test",
|
||||
Scheme: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: "8181",
|
||||
HideHeaders: hHide,
|
||||
}
|
||||
req, err := http.NewRequest("HEAD", "http://127.0.0.1:8080", nil)
|
||||
for _, k := range hHide {
|
||||
req.Header.Set(k, "foo")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hGot := resp.Header
|
||||
t.Log("headers: ", hGot)
|
||||
for _, v := range hHide {
|
||||
_, ok := hGot[v]
|
||||
if ok {
|
||||
t.Errorf("header %s: expected hidden, got %v", v, hGot[v])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ func NewRoute(cfg *ProxyConfig) (Route, error) {
|
||||
if err != nil {
|
||||
return nil, NewNestedErrorFrom(err).Subject(cfg.Alias)
|
||||
}
|
||||
httpRoutes.Get(cfg.Alias).Add(cfg.Path, route)
|
||||
return route, nil
|
||||
}
|
||||
}
|
||||
@@ -43,4 +42,4 @@ func isStreamScheme(s string) bool {
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,10 +48,18 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
|
||||
var srcPort, dstPort string
|
||||
var srcScheme, dstScheme string
|
||||
|
||||
l := srlog.WithFields(logrus.Fields{
|
||||
"alias": config.Alias,
|
||||
})
|
||||
portSplit := strings.Split(config.Port, ":")
|
||||
if len(portSplit) != 2 {
|
||||
cfgl.Warnf("invalid port %s, assuming it is target port", config.Port)
|
||||
srcPort = "0"
|
||||
l.Warnf(
|
||||
`%s: invalid port %s,
|
||||
assuming it is target port`,
|
||||
config.Alias,
|
||||
config.Port,
|
||||
)
|
||||
srcPort = "0" // will assign later
|
||||
dstPort = config.Port
|
||||
} else {
|
||||
srcPort = portSplit[0]
|
||||
@@ -101,11 +109,7 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
|
||||
stopCh: make(chan struct{}, 1),
|
||||
connCh: make(chan interface{}),
|
||||
started: false,
|
||||
l: srlog.WithFields(logrus.Fields{
|
||||
"alias": config.Alias,
|
||||
// "src": fmt.Sprintf("%s://:%d", srcScheme, srcPortInt),
|
||||
// "dst": fmt.Sprintf("%s://%s:%d", dstScheme, config.Host, dstPortInt),
|
||||
}),
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -235,4 +239,4 @@ func (route *StreamRouteBase) grHandleConnections() {
|
||||
// id -> target
|
||||
type StreamRoutes SafeMap[string, StreamRoute]
|
||||
|
||||
var streamRoutes StreamRoutes = NewSafeMapOf[StreamRoutes]()
|
||||
var streamRoutes StreamRoutes = NewSafeMapOf[StreamRoutes]()
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -212,13 +213,17 @@ func setFieldFromSnake[T interface{}, VT interface{}](obj *T, field string, valu
|
||||
field = utils.snakeToPascal(field)
|
||||
prop := reflect.ValueOf(obj).Elem().FieldByName(field)
|
||||
if prop.Kind() == 0 {
|
||||
return NewNestedError("unknown field").Subject(field)
|
||||
return errors.New("unknown field")
|
||||
}
|
||||
prop.Set(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateYaml(schema *jsonschema.Schema, data []byte) error {
|
||||
if noSchemaValidation {
|
||||
return nil
|
||||
}
|
||||
|
||||
var i interface{}
|
||||
|
||||
err := yaml.Unmarshal(data, &i)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,8 +22,6 @@ type Watcher interface {
|
||||
}
|
||||
|
||||
type watcherBase struct {
|
||||
name string // for log / error output
|
||||
kind string // for log / error output
|
||||
onChange func()
|
||||
l logrus.FieldLogger
|
||||
sync.Mutex
|
||||
@@ -42,30 +40,44 @@ type dockerWatcher struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newWatcher(kind string, name string, onChange func()) *watcherBase {
|
||||
func (p *Provider) newWatcher() *watcherBase {
|
||||
return &watcherBase{
|
||||
kind: kind,
|
||||
name: name,
|
||||
onChange: onChange,
|
||||
l: wlog.WithFields(logrus.Fields{"kind": kind, "name": name}),
|
||||
}
|
||||
}
|
||||
func NewFileWatcher(p string, onChange func(), onDelete func()) Watcher {
|
||||
return &fileWatcher{
|
||||
watcherBase: newWatcher("File", path.Base(p), onChange),
|
||||
path: p,
|
||||
onDelete: onDelete,
|
||||
onChange: p.ReloadRoutes,
|
||||
l: p.l,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDockerWatcher(c *client.Client, onChange func()) Watcher {
|
||||
func (p *Provider) NewFileWatcher() Watcher {
|
||||
return &fileWatcher{
|
||||
watcherBase: p.newWatcher(),
|
||||
path: p.GetFilePath(),
|
||||
onDelete: p.StopAllRoutes,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) NewDockerWatcher(c *client.Client) Watcher {
|
||||
return &dockerWatcher{
|
||||
watcherBase: newWatcher("Docker", c.DaemonHost(), onChange),
|
||||
watcherBase: p.newWatcher(),
|
||||
client: c,
|
||||
stopCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *config) newWatcher() *watcherBase {
|
||||
return &watcherBase{
|
||||
onChange: c.MustReload,
|
||||
l: c.l,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *config) NewFileWatcher() Watcher {
|
||||
return &fileWatcher{
|
||||
watcherBase: c.newWatcher(),
|
||||
path: c.reader.(*FileReader).Path,
|
||||
onDelete: func() { c.l.Fatal("config file deleted") },
|
||||
}
|
||||
}
|
||||
|
||||
func (w *fileWatcher) Start() {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
@@ -100,7 +112,7 @@ func (w *fileWatcher) Dispose() {
|
||||
func (w *dockerWatcher) Start() {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
dockerWatchMap.Set(w.name, w)
|
||||
dockerWatchMap.Set(w.client.DaemonHost(), w)
|
||||
w.wg.Add(1)
|
||||
go w.watch()
|
||||
}
|
||||
@@ -114,7 +126,7 @@ func (w *dockerWatcher) Stop() {
|
||||
close(w.stopCh)
|
||||
w.wg.Wait()
|
||||
w.stopCh = nil
|
||||
dockerWatchMap.Delete(w.name)
|
||||
dockerWatchMap.Delete(w.client.DaemonHost())
|
||||
}
|
||||
|
||||
func (w *dockerWatcher) Dispose() {
|
||||
@@ -164,10 +176,10 @@ func watchFiles() {
|
||||
}
|
||||
switch {
|
||||
case event.Has(fsnotify.Write):
|
||||
w.l.Info("file changed")
|
||||
w.l.Info("file changed: ", event.Name)
|
||||
go w.onChange()
|
||||
case event.Has(fsnotify.Remove), event.Has(fsnotify.Rename):
|
||||
w.l.Info("file renamed / deleted")
|
||||
w.l.Info("file renamed / deleted: ", event.Name)
|
||||
go w.onDelete()
|
||||
}
|
||||
case err := <-fsWatcher.Errors:
|
||||
@@ -194,16 +206,20 @@ func (w *dockerWatcher) watch() {
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case msg := <-msgChan:
|
||||
w.l.Infof("container %s %s", msg.Actor.Attributes["name"], msg.Action)
|
||||
containerName := msg.Actor.Attributes["name"]
|
||||
if strings.HasPrefix(containerName, "buildx_buildkit_builder-") {
|
||||
continue
|
||||
}
|
||||
w.l.Infof("container %s %s", containerName, msg.Action)
|
||||
go w.onChange()
|
||||
case err := <-errChan:
|
||||
switch {
|
||||
case client.IsErrConnectionFailed(err):
|
||||
w.l.Error(NewNestedError("connection failed").Subject(w.name))
|
||||
w.l.Error("watcher: connection failed")
|
||||
case client.IsErrNotFound(err):
|
||||
w.l.Error(NewNestedError("endpoint not found").Subject(w.name))
|
||||
w.l.Error("watcher: endpoint not found")
|
||||
default:
|
||||
w.l.Error(NewNestedErrorFrom(err).Subject(w.name))
|
||||
w.l.Errorf("watcher: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
msgChan, errChan = listen()
|
||||
|
||||
Reference in New Issue
Block a user