diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 84e3c8d9..d35cc0b9 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -130,3 +130,18 @@ jobs: run: | docker tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest + scan: + runs-on: ubuntu-latest + needs: + - merge + steps: + - name: Scan Image with Trivy + uses: aquasecurity/trivy-action@0.20.0 + with: + image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest + format: "sarif" + output: "trivy-results.sarif" + - name: Upload Trivy SARIF Report + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" diff --git a/Dockerfile b/Dockerfile index 7a6c7878..83d966e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,15 +9,15 @@ COPY src/go.mod src/go.sum ./ # Utilize build cache RUN --mount=type=cache,target="/go/pkg/mod" \ - go mod download + go mod graph | awk '{if ($1 !~ "@") print $2}' | xargs go get -# Now copy the remaining files -COPY src/ ./ +ENV GOCACHE=/root/.cache/go-build # Build the application with better caching RUN --mount=type=cache,target="/go/pkg/mod" \ --mount=type=cache,target="/root/.cache/go-build" \ - CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy ./ + --mount=type=bind,src=src,dst=/src \ + CGO_ENABLED=0 GOOS=linux go build -ldflags '-w -s' -pgo=auto -o /go-proxy . # Stage 2: Final image FROM scratch @@ -28,7 +28,7 @@ LABEL maintainer="yusing@6uo.me" COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo # copy binary -COPY --from=builder /src/go-proxy /app/ +COPY --from=builder /go-proxy /app/ # copy schema directory COPY schema/ /app/schema/ diff --git a/Makefile b/Makefile index 0b628531..11e77848 100755 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build up quick-restart restart logs get udp-server +.PHONY: all setup build test up restart logs get debug run archive repush rapid-crash debug-list-containers all: debug @@ -9,7 +9,8 @@ setup: build: mkdir -p bin - CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy + CGO_ENABLED=0 GOOS=linux \ + go build -ldflags '${BUILD_FLAG}' -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy test: go test ./src/... @@ -29,6 +30,9 @@ get: debug: make build && sudo GOPROXY_DEBUG=1 bin/go-proxy +run: + BUILD_FLAG="-s -w" make build && sudo bin/go-proxy + archive: git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip @@ -44,4 +48,4 @@ rapid-crash: sudo docker rm -f test_crash debug-list-containers: - bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq' \ No newline at end of file + bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq' diff --git a/config.example.yml b/config.example.yml index fa3955a2..8bd09e3f 100644 --- a/config.example.yml +++ b/config.example.yml @@ -1,36 +1,64 @@ # Autocert (choose one below and uncomment to enable) - +# # 1. use existing cert +# # autocert: # provider: local -# cert_path: certs/cert.crt # optional, uncomment only if you need to change it -# key_path: certs/priv.key # optional, uncomment only if you need to change it - +# +# cert_path: certs/cert.crt # optional, uncomment only if you need to change it +# key_path: certs/priv.key # optional, uncomment only if you need to change it +# # 2. cloudflare +# # autocert: # provider: cloudflare -# email: # ACME Email -# domains: # a list of domains for cert registration -# - x.y.z +# email: abc@gmail.com # ACME Email +# domains: # a list of domains for cert registration +# - "*.y.z" # remember to use double quotes to surround wildcard domain # options: -# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token - +# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token +# # 3. other providers, check docs/dns_providers.md for more providers: + # include files are standalone yaml files under `config/` directory + # # include: - # - providers.yml # config/providers.yml - # # add some more below if you want - # - file1.yml # config/file_1.yml + # - file1.yml # - file2.yml + docker: - # for value format, see https://docs.docker.com/reference/cli/dockerd/ - # $DOCKER_HOST implies unix:///var/run/docker.sock by default + # $DOCKER_HOST implies environment variable `DOCKER_HOST` or unix:///var/run/docker.sock by default local: $DOCKER_HOST + # add more docker providers if needed + # for value format, see https://docs.docker.com/reference/cli/dockerd/ + # # remote-1: tcp://10.0.2.1:2375 # remote-2: ssh://root:1234@10.0.2.2 -# Fixed options (optional, non hot-reloadable) +# if match_domains not defined +# any host = alias+[any domain] will match +# i.e. https://app1.y.z will match alias app1 for any domain y.z +# but https://app1.node1.y.z will only match alias "app.node1" +# +# if match_domains defined +# only host = alias+[one of match_domains] will match +# i.e. match_domains = [node1.my.app, my.site] +# https://app1.my.app, https://app1.my.net, etc. will not match even if app1 exists +# only https://*.node1.my.app and https://*.my.site will match +# +# +# match_domains: +# - my.site +# - node1.my.app +# Below are fixed options (non hot-reloadable) + +# timeout for shutdown (in seconds) +# # timeout_shutdown: 5 -# redirect_to_https: false # redirect http requests to https (if enabled) + +# global setting redirect http requests to https (if https available, otherwise this will be ignored) +# proxy..middlewares.redirect_http will override this +# +# redirect_to_https: false diff --git a/docs/docker.md b/docs/docker.md index b8285f97..8304b686 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -74,7 +74,7 @@ | `proxy.stop_timeout` | time to wait for stop command | | `10s` | `number[unit]...` | | `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix | | `proxy..` | set field for specific alias | `proxy.gitlab-ssh.scheme` | N/A | N/A | -| `proxy.$.` | set field for specific alias at index (starting from **1**) | `proxy.$3.port` | N/A | N/A | +| `proxy.#.` | set field for specific alias at index (starting from **1**) | `proxy.#3.port` | N/A | N/A | | `proxy.*.` | set field for all aliases | `proxy.*.set_headers` | N/A | N/A | ### Fields diff --git a/schema/config.schema.json b/schema/config.schema.json index 96f9115b..5af8576f 100644 --- a/schema/config.schema.json +++ b/schema/config.schema.json @@ -37,7 +37,13 @@ "title": "DNS Challenge Provider", "default": "local", "type": "string", - "enum": ["local", "cloudflare", "clouddns", "duckdns", "ovh"] + "enum": [ + "local", + "cloudflare", + "clouddns", + "duckdns", + "ovh" + ] }, "options": { "title": "Provider specific options", @@ -56,7 +62,12 @@ } }, "then": { - "required": ["email", "domains", "provider", "options"] + "required": [ + "email", + "domains", + "provider", + "options" + ] } }, { @@ -70,7 +81,9 @@ "then": { "properties": { "options": { - "required": ["auth_token"], + "required": [ + "auth_token" + ], "additionalProperties": false, "properties": { "auth_token": { @@ -93,7 +106,11 @@ "then": { "properties": { "options": { - "required": ["client_id", "email", "password"], + "required": [ + "client_id", + "email", + "password" + ], "additionalProperties": false, "properties": { "client_id": { @@ -124,7 +141,9 @@ "then": { "properties": { "options": { - "required": ["token"], + "required": [ + "token" + ], "additionalProperties": false, "properties": { "token": { @@ -147,14 +166,21 @@ "then": { "properties": { "options": { - "required": ["application_secret", "consumer_key"], + "required": [ + "application_secret", + "consumer_key" + ], "additionalProperties": false, "oneOf": [ { - "required": ["application_key"] + "required": [ + "application_key" + ] }, { - "required": ["oauth2_config"] + "required": [ + "oauth2_config" + ] } ], "properties": { @@ -205,7 +231,10 @@ "type": "string" } }, - "required": ["client_id", "client_secret"] + "required": [ + "client_id", + "client_secret" + ] } } } @@ -268,6 +297,14 @@ } } }, + "match_domains": { + "title": "Domains to match", + "type": "array", + "items": { + "type": "string" + }, + "minItems": 1 + }, "timeout_shutdown": { "title": "Shutdown timeout (in seconds)", "type": "integer", @@ -279,5 +316,7 @@ } }, "additionalProperties": false, - "required": ["providers"] -} + "required": [ + "providers" + ] +} \ No newline at end of file diff --git a/src/autocert/provider.go b/src/autocert/provider.go index 3feb30ef..97fb8025 100644 --- a/src/autocert/provider.go +++ b/src/autocert/provider.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "os" + "path" "reflect" "sort" "time" @@ -59,8 +60,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) { defer b.To(&res) if p.cfg.Provider == ProviderLocal { - b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning) - return + return nil } if p.client == nil { @@ -191,7 +191,19 @@ func (p *Provider) registerACME() E.NestedError { } func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { - err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- + //* This should have been done in setup + //* but double check is always a good choice + _, err := os.Stat(path.Dir(p.cfg.CertPath)) + if err != nil { + if os.IsNotExist(err) { + if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { + return E.FailWith("create cert directory", err) + } + } else { + return E.FailWith("stat cert directory", err) + } + } + err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- if err != nil { return E.FailWith("write key file", err) } @@ -227,6 +239,10 @@ func (p *Provider) certState() CertState { } func (p *Provider) renewIfNeeded() E.NestedError { + if p.cfg.Provider == ProviderLocal { + return nil + } + switch p.certState() { case CertStateExpired: logger.Info("certs expired, renewing") diff --git a/src/autocert/setup.go b/src/autocert/setup.go index 72fe3f44..e8d1ca2c 100644 --- a/src/autocert/setup.go +++ b/src/autocert/setup.go @@ -14,7 +14,7 @@ func (p *Provider) Setup(ctx context.Context) (err E.NestedError) { } logger.Debug("obtaining cert due to error loading cert") if err = p.ObtainCert(); err != nil { - return err.Warn() + return err } } diff --git a/src/common/http.go b/src/common/http.go new file mode 100644 index 00000000..fff08ad0 --- /dev/null +++ b/src/common/http.go @@ -0,0 +1,25 @@ +package common + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +var ( + defaultDialer = net.Dialer{ + Timeout: 60 * time.Second, + KeepAlive: 60 * time.Second, + } + DefaultTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: defaultDialer.DialContext, + MaxIdleConnsPerHost: 1000, + } + DefaultTransportNoTLS = func() *http.Transport { + var clone = DefaultTransport.Clone() + clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + return clone + }() +) diff --git a/src/config/config.go b/src/config/config.go index e4864aed..ef7b1ca4 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -31,25 +31,48 @@ type Config struct { reloadReq chan struct{} } -func Load() (*Config, E.NestedError) { - cfg := &Config{ +var instance *Config + +func GetConfig() *Config { + return instance +} + +func Load() E.NestedError { + if instance != nil { + return nil + } + instance = &Config{ + value: M.DefaultConfig(), proxyProviders: F.NewMapOf[string, *PR.Provider](), l: logrus.WithField("module", "config"), watcher: W.NewFileWatcher(common.ConfigFileName), reloadReq: make(chan struct{}, 1), } - return cfg, cfg.load() + return instance.load() } func Validate(data []byte) E.NestedError { return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data) } +func MatchDomains() []string { + if instance == nil { + logrus.Panic("config has not been loaded, please check if there is any errors") + } + return instance.value.MatchDomains +} + func (cfg *Config) Value() M.Config { + if cfg == nil { + logrus.Panic("config has not been loaded, please check if there is any errors") + } return *cfg.value } func (cfg *Config) GetAutoCertProvider() *autocert.Provider { + if instance == nil { + logrus.Panic("config has not been loaded, please check if there is any errors") + } return cfg.autocertProvider } @@ -61,13 +84,11 @@ func (cfg *Config) Dispose() { cfg.stopProviders() } -func (cfg *Config) Reload() E.NestedError { +func (cfg *Config) Reload() (err E.NestedError) { cfg.stopProviders() - if err := cfg.load(); err.HasError() { - return err - } + err = cfg.load() cfg.StartProxyProviders() - return nil + return } func (cfg *Config) StartProxyProviders() { @@ -126,28 +147,28 @@ func (cfg *Config) load() (res E.NestedError) { data, err := E.Check(os.ReadFile(common.ConfigPath)) if err.HasError() { b.Add(E.FailWith("read config", err)) - return + logrus.Fatal(b.Build()) } if !common.NoSchemaValidation { if err = Validate(data); err.HasError() { b.Add(E.FailWith("schema validation", err)) - return + logrus.Fatal(b.Build()) } } model := M.DefaultConfig() if err := E.From(yaml.Unmarshal(data, model)); err.HasError() { b.Add(E.FailWith("parse config", err)) - return + logrus.Fatal(b.Build()) } // errors are non fatal below - b.WithSeverity(E.SeverityWarning) b.Add(cfg.initAutoCert(&model.AutoCert)) b.Add(cfg.loadProviders(&model.Providers)) cfg.value = model + R.SetFindMuxDomains(model.MatchDomains) return } diff --git a/src/docker/label.go b/src/docker/label.go index 57746c00..edac8b9f 100644 --- a/src/docker/label.go +++ b/src/docker/label.go @@ -1,17 +1,37 @@ package docker import ( + "reflect" "strings" E "github.com/yusing/go-proxy/error" U "github.com/yusing/go-proxy/utils" + F "github.com/yusing/go-proxy/utils/functional" ) -type Label struct { - Namespace string - Target string - Attribute string - Value any +/* +Formats: + - namespace.attribute + - namespace.target.attribute + - namespace.target.attribute.namespace2.attribute +*/ +type ( + Label struct { + Namespace string + Target string + Attribute string + Value any + } + NestedLabelMap map[string]U.SerializedObject + ValueParser func(string) (any, E.NestedError) + ValueParserMap map[string]ValueParser +) + +func (l *Label) String() string { + if l.Attribute == "" { + return l.Namespace + "." + l.Target + } + return l.Namespace + "." + l.Target + "." + l.Attribute } // Apply applies the value of a Label to the corresponding field in the given object. @@ -23,12 +43,40 @@ type Label struct { // Returns: // - error: an error if the field does not exist. func ApplyLabel[T any](obj *T, l *Label) E.NestedError { - return U.Deserialize(map[string]any{l.Attribute: l.Value}, obj) + if obj == nil { + return E.Invalid("nil object", l) + } + switch nestedLabel := l.Value.(type) { + case *Label: + var field reflect.Value + objType := reflect.TypeFor[T]() + for i := 0; i < reflect.TypeFor[T]().NumField(); i++ { + if objType.Field(i).Tag.Get("yaml") == l.Attribute { + field = reflect.ValueOf(obj).Elem().Field(i) + break + } + } + if !field.IsValid() { + return E.NotExist("field", l.Attribute) + } + dst, ok := field.Interface().(NestedLabelMap) + if !ok { + return E.Invalid("type", field.Type()) + } + if dst == nil { + field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]())) + dst = field.Interface().(NestedLabelMap) + } + if dst[nestedLabel.Namespace] == nil { + dst[nestedLabel.Namespace] = make(U.SerializedObject) + } + dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value + return nil + default: + return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) + } } -type ValueParser func(string) (any, E.NestedError) -type ValueParserMap map[string]ValueParser - func ParseLabel(label string, value string) (*Label, E.NestedError) { parts := strings.Split(label, ".") @@ -45,14 +93,22 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { Value: value, } - if len(parts) == 3 { - l.Attribute = parts[2] - } else { + switch len(parts) { + case 2: l.Attribute = l.Target + case 3: + l.Attribute = parts[2] + default: + l.Attribute = parts[2] + nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value) + if err.HasError() { + return nil, err + } + l.Value = nestedLabel } // find if namespace has value parser - pm, ok := labelValueParserMap[l.Namespace] + pm, ok := valueParserMap.Load(l.Namespace) if !ok { return l, nil } @@ -64,15 +120,28 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { // try to parse value v, err := p(value) if err.HasError() { - return nil, err + return nil, err.Subject(label) } l.Value = v return l, nil } func RegisterNamespace(namespace string, pm ValueParserMap) { - labelValueParserMap[namespace] = pm + valueParserMap.Store(namespace, pm) +} + +func GetRegisteredNamespaces() map[string][]string { + r := make(map[string][]string) + + valueParserMap.RangeAll(func(ns string, vpm ValueParserMap) { + r[ns] = make([]string, 0, len(vpm)) + for attr := range vpm { + r[ns] = append(r[ns], attr) + } + }) + + return r } // namespace:target.attribute -> func(string) (any, error) -var labelValueParserMap = make(map[string]ValueParserMap) +var valueParserMap = F.NewMapOf[string, ValueParserMap]() diff --git a/src/docker/label_parser.go b/src/docker/label_parser.go index b89be9fd..def4b076 100644 --- a/src/docker/label_parser.go +++ b/src/docker/label_parser.go @@ -7,7 +7,27 @@ import ( "gopkg.in/yaml.v3" ) -func yamlListParser(value string) (any, E.NestedError) { +const ( + NSProxy = "proxy" + ProxyAttributePathPatterns = "path_patterns" + ProxyAttributeNoTLSVerify = "no_tls_verify" + ProxyAttributeMiddlewares = "middlewares" +) + +var _ = func() int { + RegisterNamespace(NSProxy, ValueParserMap{ + ProxyAttributePathPatterns: YamlStringListParser, + ProxyAttributeNoTLSVerify: BoolParser, + }) + return 0 +}() + +func YamlStringListParser(value string) (any, E.NestedError) { + /* + - foo + - bar + - baz + */ value = strings.TrimSpace(value) if value == "" { return []string{}, nil @@ -17,27 +37,36 @@ func yamlListParser(value string) (any, E.NestedError) { return data, err } -func yamlStringMappingParser(value string) (any, E.NestedError) { - value = strings.TrimSpace(value) - lines := strings.Split(value, "\n") - h := make(map[string]string) - for _, line := range lines { - parts := strings.SplitN(line, ":", 2) - if len(parts) != 2 { - return nil, E.Invalid("set header statement", line) - } - key := strings.TrimSpace(parts[0]) - val := strings.TrimSpace(parts[1]) - if existing, ok := h[key]; ok { - h[key] = existing + ", " + val - } else { - h[key] = val +func YamlLikeMappingParser(allowDuplicate bool) func(string) (any, E.NestedError) { + return func(value string) (any, E.NestedError) { + /* + foo: bar + boo: baz + */ + value = strings.TrimSpace(value) + lines := strings.Split(value, "\n") + h := make(map[string]string) + for _, line := range lines { + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + return nil, E.Invalid("syntax", line) + } + key := strings.TrimSpace(parts[0]) + val := strings.TrimSpace(parts[1]) + if existing, ok := h[key]; ok { + if !allowDuplicate { + return nil, E.Duplicated("key", key) + } + h[key] = existing + ", " + val + } else { + h[key] = val + } } + return h, nil } - return h, nil } -func boolParser(value string) (any, E.NestedError) { +func BoolParser(value string) (any, E.NestedError) { switch strings.ToLower(value) { case "true", "yes", "1": return true, nil @@ -47,15 +76,3 @@ func boolParser(value string) (any, E.NestedError) { return nil, E.Invalid("boolean value", value) } } - -const NSProxy = "proxy" - -var _ = func() int { - RegisterNamespace(NSProxy, ValueParserMap{ - "path_patterns": yamlListParser, - "set_headers": yamlStringMappingParser, - "hide_headers": yamlListParser, - "no_tls_verify": boolParser, - }) - return 0 -}() diff --git a/src/docker/label_parser_test.go b/src/docker/label_parser_test.go index 57faf939..4596f902 100644 --- a/src/docker/label_parser_test.go +++ b/src/docker/label_parser_test.go @@ -2,8 +2,6 @@ package docker import ( "fmt" - "reflect" - "strings" "testing" E "github.com/yusing/go-proxy/error" @@ -14,21 +12,16 @@ func makeLabel(namespace string, alias string, field string) string { return fmt.Sprintf("%s.%s.%s", namespace, alias, field) } -func TestHomePageLabel(t *testing.T) { +func TestParseLabel(t *testing.T) { alias := "foo" field := "ip" v := "bar" pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v) ExpectNoError(t, err.Error()) - if pl.Target != alias { - t.Errorf("Expected alias=%s, got %s", alias, pl.Target) - } - if pl.Attribute != field { - t.Errorf("Expected field=%s, got %s", field, pl.Target) - } - if pl.Value != v { - t.Errorf("Expected value=%q, got %s", v, pl.Value) - } + ExpectEqual(t, pl.Namespace, NSHomePage) + ExpectEqual(t, pl.Target, alias) + ExpectEqual(t, pl.Attribute, field) + ExpectEqual(t, pl.Value.(string), v) } func TestStringProxyLabel(t *testing.T) { @@ -51,90 +44,63 @@ func TestBoolProxyLabelValid(t *testing.T) { } for k, v := range tests { - pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k) + pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), k) ExpectNoError(t, err.Error()) ExpectEqual(t, pl.Value.(bool), v) } } func TestBoolProxyLabelInvalid(t *testing.T) { - alias := "foo" - field := "no_tls_verify" - _, err := ParseLabel(makeLabel(NSProxy, alias, field), "invalid") + _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), "invalid") if !err.Is(E.ErrInvalid) { t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error()) } } -func TestSetHeaderProxyLabelValid(t *testing.T) { - v := ` -X-Custom-Header1: foo, bar -X-Custom-Header1: baz -X-Custom-Header2: boo` - v = strings.TrimPrefix(v, "\n") - h := map[string]string{ - "X-Custom-Header1": "foo, bar, baz", - "X-Custom-Header2": "boo", - } +// func TestSetHeaderProxyLabelValid(t *testing.T) { +// v := ` +// X-Custom-Header1: foo, bar +// X-Custom-Header1: baz +// X-Custom-Header2: boo` +// v = strings.TrimPrefix(v, "\n") +// h := map[string]string{ +// "X-Custom-Header1": "foo, bar, baz", +// "X-Custom-Header2": "boo", +// } - pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v) - ExpectNoError(t, err.Error()) - hGot := ExpectType[map[string]string](t, pl.Value) - if hGot != nil && !reflect.DeepEqual(h, hGot) { - t.Errorf("Expected %v, got %v", h, hGot) - } +// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v) +// ExpectNoError(t, err.Error()) +// hGot := ExpectType[map[string]string](t, pl.Value) +// ExpectFalse(t, hGot == nil) +// ExpectDeepEqual(t, h, hGot) +// } -} +// func TestSetHeaderProxyLabelInvalid(t *testing.T) { +// tests := []string{ +// "X-Custom-Header1 = bar", +// "X-Custom-Header1", +// "- X-Custom-Header1", +// } -func TestSetHeaderProxyLabelInvalid(t *testing.T) { - tests := []string{ - "X-Custom-Header1 = bar", - "X-Custom-Header1", - "- X-Custom-Header1", - } - - for _, v := range tests { - _, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v) - if !err.Is(E.ErrInvalid) { - t.Errorf("Expected invalid err for %q, got %s", v, err.Error()) - } - } -} - -func TestHideHeadersProxyLabel(t *testing.T) { - v := ` -- X-Custom-Header1 -- X-Custom-Header2 -- X-Custom-Header3 -` - v = strings.TrimPrefix(v, "\n") - pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v) - ExpectNoError(t, err.Error()) - sGot := ExpectType[[]string](t, pl.Value) - sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} - if sGot != nil { - ExpectDeepEqual(t, sGot, sWant) - } -} - -// func TestCommaSepProxyLabelSingle(t *testing.T) { -// v := "a" -// pl, err := ParseLabel("proxy.aliases", v) -// ExpectNoError(t, err) -// sGot := ExpectType[[]string](t, pl.Value) -// sWant := []string{"a"} -// if sGot != nil { -// ExpectEqual(t, sGot, sWant) +// for _, v := range tests { +// _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v) +// if !err.Is(E.ErrInvalid) { +// t.Errorf("Expected invalid err for %q, got %s", v, err.Error()) +// } // } // } -// func TestCommaSepProxyLabelMulti(t *testing.T) { -// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3" -// pl, err := ParseLabel("proxy.aliases", v) -// ExpectNoError(t, err) +// func TestHideHeadersProxyLabel(t *testing.T) { +// v := ` +// - X-Custom-Header1 +// - X-Custom-Header2 +// - X-Custom-Header3 +// ` +// v = strings.TrimPrefix(v, "\n") +// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeHideHeaders), v) +// ExpectNoError(t, err.Error()) // sGot := ExpectType[[]string](t, pl.Value) // sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} -// if sGot != nil { -// ExpectEqual(t, sGot, sWant) -// } +// ExpectFalse(t, sGot == nil) +// ExpectDeepEqual(t, sGot, sWant) // } diff --git a/src/docker/label_test.go b/src/docker/label_test.go new file mode 100644 index 00000000..045001ab --- /dev/null +++ b/src/docker/label_test.go @@ -0,0 +1,85 @@ +package docker + +import ( + "fmt" + "testing" + + U "github.com/yusing/go-proxy/utils" + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestNestedLabel(t *testing.T) { + mName := "middleware1" + mAttr := "prop1" + v := "value1" + pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v) + ExpectNoError(t, err.Error()) + sGot := ExpectType[*Label](t, pl.Value) + ExpectFalse(t, sGot == nil) + ExpectEqual(t, sGot.Namespace, mName) + ExpectEqual(t, sGot.Attribute, mAttr) +} + +func TestApplyNestedLabel(t *testing.T) { + entry := new(struct { + Middlewares NestedLabelMap `yaml:"middlewares"` + }) + mName := "middleware1" + mAttr := "prop1" + v := "value1" + pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v) + ExpectNoError(t, err.Error()) + err = ApplyLabel(entry, pl) + ExpectNoError(t, err.Error()) + middleware1, ok := entry.Middlewares[mName] + ExpectTrue(t, ok) + got := ExpectType[string](t, middleware1[mAttr]) + ExpectEqual(t, got, v) +} + +func TestApplyNestedLabelExisting(t *testing.T) { + mName := "middleware1" + mAttr := "prop1" + v := "value1" + + checkAttr := "prop2" + checkV := "value2" + entry := new(struct { + Middlewares NestedLabelMap `yaml:"middlewares"` + }) + entry.Middlewares = make(NestedLabelMap) + entry.Middlewares[mName] = make(U.SerializedObject) + entry.Middlewares[mName][checkAttr] = checkV + + pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v) + ExpectNoError(t, err.Error()) + err = ApplyLabel(entry, pl) + ExpectNoError(t, err.Error()) + middleware1, ok := entry.Middlewares[mName] + ExpectTrue(t, ok) + got := ExpectType[string](t, middleware1[mAttr]) + ExpectEqual(t, got, v) + + // check if prop2 is affected + ExpectFalse(t, middleware1[checkAttr] == nil) + got = ExpectType[string](t, middleware1[checkAttr]) + ExpectEqual(t, got, checkV) +} + +func TestApplyNestedLabelNoAttr(t *testing.T) { + mName := "middleware1" + v := "value1" + + entry := new(struct { + Middlewares NestedLabelMap `yaml:"middlewares"` + }) + entry.Middlewares = make(NestedLabelMap) + entry.Middlewares[mName] = make(U.SerializedObject) + + pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", ProxyAttributeMiddlewares, mName)), v) + ExpectNoError(t, err.Error()) + err = ApplyLabel(entry, pl) + ExpectNoError(t, err.Error()) + _, ok := entry.Middlewares[mName] + ExpectTrue(t, ok) +} diff --git a/src/error/builder.go b/src/error/builder.go index 1a44a091..6bc24429 100644 --- a/src/error/builder.go +++ b/src/error/builder.go @@ -10,9 +10,8 @@ type Builder struct { } type builder struct { - message string - errors []NestedError - severity Severity + message string + errors []NestedError sync.Mutex } @@ -40,11 +39,6 @@ func (b Builder) Addf(format string, args ...any) Builder { return b.Add(errorf(format, args...)) } -func (b Builder) WithSeverity(s Severity) Builder { - b.severity = s - return b -} - // Build builds a NestedError based on the errors collected in the Builder. // // If there are no errors in the Builder, it returns a Nil() NestedError. @@ -58,7 +52,7 @@ func (b Builder) Build() NestedError { } else if len(b.errors) == 1 { return b.errors[0] } - return Join(b.message, b.errors...).Severity(b.severity) + return Join(b.message, b.errors...) } func (b Builder) To(ptr *NestedError) { diff --git a/src/error/error.go b/src/error/error.go index c18386d0..fc2b246d 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -9,17 +9,10 @@ import ( type ( NestedError = *nestedError nestedError struct { - subject string - err error - extras []nestedError - severity Severity + subject string + err error + extras []nestedError } - Severity uint8 -) - -const ( - SeverityWarning Severity = iota - SeverityFatal ) func From(err error) NestedError { @@ -164,22 +157,6 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError { return ne } -func (ne NestedError) Severity(s Severity) NestedError { - if ne == nil { - return ne - } - ne.severity = s - return ne -} - -func (ne NestedError) Warn() NestedError { - if ne == nil { - return ne - } - ne.severity = SeverityWarning - return ne -} - func (ne NestedError) NoError() bool { return ne == nil } @@ -188,14 +165,6 @@ func (ne NestedError) HasError() bool { return ne != nil } -func (ne NestedError) IsFatal() bool { - return ne != nil && ne.severity == SeverityFatal -} - -func (ne NestedError) IsWarning() bool { - return ne != nil && ne.severity == SeverityWarning -} - func errorf(format string, args ...any) NestedError { return From(fmt.Errorf(format, args...)) } diff --git a/src/error/error_test.go b/src/error/error_test.go index 7e465af0..b46cb18c 100644 --- a/src/error/error_test.go +++ b/src/error/error_test.go @@ -31,11 +31,11 @@ func TestErrorNestedIs(t *testing.T) { err = Failure("some reason") ExpectTrue(t, err.Is(ErrFailure)) - ExpectFalse(t, err.Is(ErrAlreadyExist)) + ExpectFalse(t, err.Is(ErrDuplicated)) - err.With(AlreadyExist("something", "")) + err.With(Duplicated("something", "")) ExpectTrue(t, err.Is(ErrFailure)) - ExpectTrue(t, err.Is(ErrAlreadyExist)) + ExpectTrue(t, err.Is(ErrDuplicated)) ExpectFalse(t, err.Is(ErrInvalid)) } diff --git a/src/error/errors.go b/src/error/errors.go index c3bccb8e..aa0aa4ba 100644 --- a/src/error/errors.go +++ b/src/error/errors.go @@ -5,14 +5,14 @@ import ( ) var ( - ErrFailure = stderrors.New("failed") - ErrInvalid = stderrors.New("invalid") - ErrUnsupported = stderrors.New("unsupported") - ErrUnexpected = stderrors.New("unexpected") - ErrNotExists = stderrors.New("does not exist") - ErrMissing = stderrors.New("missing") - ErrAlreadyExist = stderrors.New("already exist") - ErrOutOfRange = stderrors.New("out of range") + ErrFailure = stderrors.New("failed") + ErrInvalid = stderrors.New("invalid") + ErrUnsupported = stderrors.New("unsupported") + ErrUnexpected = stderrors.New("unexpected") + ErrNotExists = stderrors.New("does not exist") + ErrMissing = stderrors.New("missing") + ErrDuplicated = stderrors.New("duplicated") + ErrOutOfRange = stderrors.New("out of range") ) const fmtSubjectWhat = "%w %v: %q" @@ -53,8 +53,8 @@ func Missing(subject any) NestedError { return errorf("%w %v", ErrMissing, subject) } -func AlreadyExist(subject, what any) NestedError { - return errorf("%v %w: %v", subject, ErrAlreadyExist, what) +func Duplicated(subject, what any) NestedError { + return errorf("%w %v: %v", ErrDuplicated, subject, what) } func OutOfRange(subject string, value any) NestedError { diff --git a/src/go.mod b/src/go.mod index 5d7980d9..c273d718 100644 --- a/src/go.mod +++ b/src/go.mod @@ -17,7 +17,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cloudflare/cloudflare-go v0.104.0 // indirect + github.com/cloudflare/cloudflare-go v0.105.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect diff --git a/src/go.sum b/src/go.sum index e4b70538..2e7394cc 100644 --- a/src/go.sum +++ b/src/go.sum @@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/cloudflare/cloudflare-go v0.104.0 h1:R/lB0dZupaZbOgibAH/BRrkFbZ6Acn/WsKg2iX2xXuY= -github.com/cloudflare/cloudflare-go v0.104.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM= +github.com/cloudflare/cloudflare-go v0.105.0 h1:yu2IatITLZ4dw7/byzRrlE5DfUvtub0k9CHZ5zBlj90= +github.com/cloudflare/cloudflare-go v0.105.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/src/http/modify_response_writer.go b/src/http/modify_response_writer.go new file mode 100644 index 00000000..8ba0d72c --- /dev/null +++ b/src/http/modify_response_writer.go @@ -0,0 +1,96 @@ +// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go) +// Copyright (c) 2020-2024 Traefik Labs + +package http + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +type ModifyResponseFunc func(*http.Response) error +type ModifyResponseWriter struct { + w http.ResponseWriter + r *http.Request + + headerSent bool + code int + + modifier ModifyResponseFunc + modified bool + modifierErr error +} + +func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter { + return &ModifyResponseWriter{ + w: w, + r: r, + modifier: f, + code: http.StatusOK, + } +} + +func (w *ModifyResponseWriter) WriteHeader(code int) { + if w.headerSent { + return + } + + if code >= http.StatusContinue && code < http.StatusOK { + w.w.WriteHeader(code) + } + + defer func() { + w.headerSent = true + w.code = code + }() + + if w.modifier == nil || w.modified { + w.w.WriteHeader(code) + return + } + + resp := http.Response{ + Header: w.w.Header(), + Request: w.r, + } + + if err := w.modifier(&resp); err != nil { + w.modifierErr = err + logger.Errorf("error modifying response: %s", err) + w.w.WriteHeader(http.StatusInternalServerError) + return + } + + w.modified = true + w.w.WriteHeader(code) +} + +func (w *ModifyResponseWriter) Header() http.Header { + return w.w.Header() +} + +func (w *ModifyResponseWriter) Write(b []byte) (int, error) { + w.WriteHeader(w.code) + if w.modifierErr != nil { + return 0, w.modifierErr + } + return w.w.Write(b) +} + +// Hijack hijacks the connection. +func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := w.w.(http.Hijacker); ok { + return h.Hijack() + } + + return nil, nil, fmt.Errorf("not a hijacker: %T", w.w) +} + +// Flush sends any buffered data to the client. +func (w *ModifyResponseWriter) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/src/proxy/reverse_proxy_mod.go b/src/http/reverse_proxy_mod.go similarity index 86% rename from src/proxy/reverse_proxy_mod.go rename to src/http/reverse_proxy_mod.go index 6e75e57f..43810a3a 100644 --- a/src/proxy/reverse_proxy_mod.go +++ b/src/http/reverse_proxy_mod.go @@ -1,7 +1,13 @@ -package proxy +// Copyright 2011 The Go Authors. +// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go) +// https://cs.opensource.google/go/go/+/master:LICENSE -// A small mod on net/http/httputil/reverseproxy.go -// that doubled the performance +package http + +// This is a small mod on net/http/httputil/reverseproxy.go +// that boosts performance in some cases +// and compatible to other modules of this project +// Copyright (c) 2024 yusing import ( "context" @@ -52,6 +58,21 @@ type ProxyRequest struct { // r.SetXForwarded() // } func (r *ProxyRequest) SetXForwarded() { + clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) + if err == nil { + r.Out.Header.Set("X-Forwarded-For", clientIP) + } else { + r.Out.Header.Del("X-Forwarded-For") + } + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + if r.In.TLS == nil { + r.Out.Header.Set("X-Forwarded-Proto", "http") + } else { + r.Out.Header.Set("X-Forwarded-Proto", "https") + } +} + +func (r *ProxyRequest) AddXForwarded() { clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) if err == nil { prior := r.Out.Header["X-Forwarded-For"] @@ -104,28 +125,6 @@ type ReverseProxy struct { // If nil, http.DefaultTransport is used. Transport http.RoundTripper - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - // A negative value means to flush immediately - // after each write to the client. - // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response, or - // if its ContentLength is -1; for such responses, writes - // are flushed to the client immediately. - // FlushInterval time.Duration - - // ErrorLog specifies an optional logger for errors - // that occur when attempting to proxy the request. - // If nil, logging is done via the log package's standard logger. - // ErrorLog *log.Logger - - // BufferPool optionally specifies a buffer pool to - // get byte slices for use by io.CopyBuffer when - // copying HTTP response bodies. - // BufferPool BufferPool - // ModifyResponse is an optional function that modifies the // Response from the backend. It is called if the backend // returns a response at all, with any HTTP status code. @@ -208,36 +207,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // }, // } // -// TODO: headers in ModifyResponse -func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy { - // check on init rather than on request - var setHeaders = func(r *http.Request) {} - var hideHeaders = func(r *http.Request) {} - if len(entry.SetHeaders) > 0 { - setHeaders = func(r *http.Request) { - h := entry.SetHeaders.Clone() - for k, vv := range h { - if k == "Host" { - r.Host = vv[0] - } else { - r.Header[k] = vv - } - } - } - } - if len(entry.HideHeaders) > 0 { - hideHeaders = func(r *http.Request) { - for _, k := range entry.HideHeaders { - r.Header.Del(k) - } - } - } + +func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy { rp := &ReverseProxy{ Rewrite: func(pr *ProxyRequest) { rewriteRequestURL(pr.Out, target) - // pr.SetXForwarded() - setHeaders(pr.Out) - hideHeaders(pr.Out) }, Transport: transport, } rp.ServeHTTP = rp.serveHTTP @@ -256,6 +230,23 @@ func rewriteRequestURL(req *http.Request, target *url.URL) { } } +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { @@ -331,12 +322,14 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Close = false - reqUpType := upgradeType(outreq.Header) + reqUpType := UpgradeType(outreq.Header) if !IsPrint(reqUpType) { p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) return } + RemoveHopByHopHeaders(outreq.Header) + // Issue 21096: tell backend applications that care about trailer support // that we support trailers. (We do, but we don't go out of our way to // advertise that unless the incoming client request thought it was worth @@ -458,16 +451,34 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { } } -func upgradeType(h http.Header) string { +func UpgradeType(h http.Header) string { if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { return "" } return h.Get("Upgrade") } +// RemoveHopByHopHeaders removes hop-by-hop headers. +func RemoveHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) + reqUpType := UpgradeType(req.Header) + resUpType := UpgradeType(res.Header) if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) } diff --git a/src/main.go b/src/main.go index 3dd7d06b..79de91fc 100755 --- a/src/main.go +++ b/src/main.go @@ -47,11 +47,10 @@ func main() { logrus.SetOutput(io.Discard) } else { logrus.SetFormatter(&logrus.TextFormatter{ - DisableSorting: true, - DisableLevelTruncation: true, - FullTimestamp: true, - ForceColors: true, - TimestampFormat: "01-02 15:04:05", + DisableSorting: true, + FullTimestamp: true, + ForceColors: true, + TimestampFormat: "01-02 15:04:05", }) } @@ -76,10 +75,11 @@ func main() { return } - cfg, err := config.Load() - if err.IsFatal() { - log.Fatal(err) + err := config.Load() + if err != nil { + logrus.Warn(err) } + cfg := config.GetConfig() switch args.Command { case common.CommandListConfigs: @@ -96,6 +96,10 @@ func main() { return } + if common.IsDebug { + printJSON(docker.GetRegisteredNamespaces()) + } + cfg.StartProxyProviders() if err.HasError() { @@ -116,10 +120,7 @@ func main() { if autocert != nil { ctx, cancel := context.WithCancel(context.Background()) - if err = autocert.Setup(ctx); err != nil && err.IsWarning() { - cancel() - l.Warn(err) - } else if err.IsFatal() { + if err = autocert.Setup(ctx); err != nil { l.Fatal(err) } else { onShutdown.Add(cancel) @@ -192,7 +193,7 @@ func funcName(f func()) string { } func printJSON(obj any) { - j, err := E.Check(json.Marshal(obj)) + j, err := E.Check(json.MarshalIndent(obj, "", " ")) if err.HasError() { logrus.Fatal(err) } diff --git a/src/models/config.go b/src/models/config.go index 5fb99a2c..aca18e9c 100644 --- a/src/models/config.go +++ b/src/models/config.go @@ -3,6 +3,7 @@ package model type Config struct { Providers ProxyProviders `yaml:",flow" json:"providers"` AutoCert AutoCertConfig `yaml:",flow" json:"autocert"` + MatchDomains []string `yaml:"match_domains" json:"match_domains"` TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"` RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"` } @@ -11,6 +12,6 @@ func DefaultConfig() *Config { return &Config{ Providers: ProxyProviders{}, TimeoutShutdown: 3, - RedirectToHTTPS: true, + RedirectToHTTPS: false, } } diff --git a/src/models/raw_entry.go b/src/models/raw_entry.go index 2fa54be8..e5971c2a 100644 --- a/src/models/raw_entry.go +++ b/src/models/raw_entry.go @@ -14,14 +14,13 @@ type ( RawEntry struct { // raw entry object before validation // loaded from docker labels or yaml file - Alias string `yaml:"-" json:"-"` - Scheme string `yaml:"scheme" json:"scheme"` - Host string `yaml:"host" json:"host"` - Port string `yaml:"port" json:"port"` - NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only - PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only - SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only - HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only + Alias string `yaml:"-" json:"-"` + Scheme string `yaml:"scheme" json:"scheme"` + Host string `yaml:"host" json:"host"` + Port string `yaml:"port" json:"port"` + NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only + PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only + Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares"` /* Docker only */ *D.ProxyProperties `yaml:"-" json:"proxy_properties"` @@ -44,12 +43,16 @@ func (e *RawEntry) FillMissingFields() bool { if pp == "" { pp = strconv.Itoa(port) } - e.Scheme = "tcp" + if e.Scheme == "" { + e.Scheme = "tcp" + } } else if port, ok := ImageNamePortMap[e.ImageName]; ok { if pp == "" { pp = strconv.Itoa(port) } - e.Scheme = "http" + if e.Scheme == "" { + e.Scheme = "http" + } } else if pp == "" && e.Scheme == "https" { pp = "443" } else if pp == "" { diff --git a/src/proxy/entry.go b/src/proxy/entry.go index 80873b44..c752e796 100644 --- a/src/proxy/entry.go +++ b/src/proxy/entry.go @@ -2,10 +2,10 @@ package proxy import ( "fmt" - "net/http" "net/url" "time" + D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" T "github.com/yusing/go-proxy/proxy/fields" @@ -18,8 +18,7 @@ type ( URL *url.URL NoTLSVerify bool PathPatterns T.PathPatterns - SetHeaders http.Header - HideHeaders []string + Middlewares D.NestedLabelMap /* Docker only */ IdleTimeout time.Duration @@ -78,9 +77,6 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns) b.Add(err) - setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders) - b.Add(err) - url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) b.Add(err) @@ -111,8 +107,7 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry URL: url, NoTLSVerify: m.NoTLSVerify, PathPatterns: pathPatterns, - SetHeaders: setHeaders, - HideHeaders: m.HideHeaders, + Middlewares: m.Middlewares, IdleTimeout: idleTimeout, WakeTimeout: wakeTimeout, StopMethod: stopMethod, diff --git a/src/proxy/provider/docker.go b/src/proxy/provider/docker.go index d75891c5..d16f764d 100755 --- a/src/proxy/provider/docker.go +++ b/src/proxy/provider/docker.go @@ -4,7 +4,9 @@ import ( "fmt" "regexp" "strconv" + "strings" + "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" @@ -17,6 +19,7 @@ type DockerProvider struct { } var AliasRefRegex = regexp.MustCompile(`#\d+`) +var AliasRefRegexOld = regexp.MustCompile(`\$\d+`) func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) { hostname, err := D.ParseDockerHostname(dockerHost) @@ -152,6 +155,20 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, b := E.NewBuilder("errors in label %s", key) defer b.To(&res) + refErr := E.NewBuilder("errors parsing alias references") + replaceIndexRef := func(ref string) string { + index, err := strconv.Atoi(ref[1:]) + if err != nil { + refErr.Add(E.Invalid("integer", ref)) + return ref + } + if index < 1 || index > len(container.Aliases) { + refErr.Add(E.OutOfRange("index", ref)) + return ref + } + return container.Aliases[index-1] + } + lbl, err := D.ParseLabel(key, val) if err.HasError() { b.Add(err.Subject(key)) @@ -163,22 +180,14 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, // apply label for all aliases entries.RangeAll(func(a string, e *M.RawEntry) { if err = D.ApplyLabel(e, lbl); err.HasError() { - b.Add(err.Subject(lbl.Target)) + b.Add(err.Subjectf("alias %s", lbl.Target)) } }) } else { - refErr := E.NewBuilder("errors parsing alias references") - lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, func(ref string) string { - index, err := strconv.Atoi(ref[1:]) - if err != nil { - refErr.Add(E.Invalid("integer", ref)) - return ref - } - if index < 1 || index > len(container.Aliases) { - refErr.Add(E.OutOfRange("index", ref)) - return ref - } - return container.Aliases[index-1] + lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef) + lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string { + logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#")) + return replaceIndexRef(s) }) if refErr.HasError() { b.Add(refErr.Build()) @@ -190,7 +199,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, return } if err = D.ApplyLabel(config, lbl); err.HasError() { - b.Add(err.Subject(lbl.Target)) + b.Add(err.Subjectf("alias %s", lbl.Target)) } } return diff --git a/src/proxy/provider/docker_test.go b/src/proxy/provider/docker_test.go index a9c285ac..9e0185ae 100644 --- a/src/proxy/provider/docker_test.go +++ b/src/proxy/provider/docker_test.go @@ -132,7 +132,8 @@ func TestApplyLabel(t *testing.T) { ExpectEqual(t, b.Scheme, "http") ExpectEqual(t, b.Port, "1234") ExpectEqual(t, c.Scheme, "https") - ExpectEqual(t, c.Port, "1111") + // map does not necessary follow the order above + ExpectEqualAny(t, c.Port, []string{"1111", "1234"}) } func TestApplyLabelWithRef(t *testing.T) { @@ -142,9 +143,9 @@ func TestApplyLabelWithRef(t *testing.T) { Labels: map[string]string{ D.LabelAliases: "a,b,c", "proxy.#1.host": "localhost", - "proxy.*.port": "1111", "proxy.#1.port": "4444", "proxy.#2.port": "9999", + "proxy.#3.port": "1111", "proxy.#3.scheme": "https", }, Ports: []types.Port{ diff --git a/src/route/http.go b/src/route/http.go index b8533764..5c38d694 100755 --- a/src/route/http.go +++ b/src/route/http.go @@ -1,20 +1,20 @@ package route import ( - "crypto/tls" - "net" "sync" - "time" "net/http" "net/url" "strings" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/docker/idlewatcher" E "github.com/yusing/go-proxy/error" + . "github.com/yusing/go-proxy/http" P "github.com/yusing/go-proxy/proxy" PT "github.com/yusing/go-proxy/proxy/fields" + "github.com/yusing/go-proxy/route/middleware" F "github.com/yusing/go-proxy/utils/functional" ) @@ -26,7 +26,7 @@ type ( entry *P.ReverseProxyEntry mux *http.ServeMux - handler *P.ReverseProxy + handler *ReverseProxy regIdleWatcher func() E.NestedError unregIdleWatcher func() @@ -36,18 +36,41 @@ type ( SubdomainKey = PT.Alias ) +var ( + findMuxFunc = findMuxAnyDomain + + httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]() + httpRoutesMu sync.Mutex + globalMux = http.NewServeMux() // TODO: support regex subdomain matching +) + +func SetFindMuxDomains(domains []string) { + if len(domains) == 0 { + findMuxFunc = findMuxAnyDomain + } else { + findMuxFunc = findMuxByDomain(domains) + } +} + func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { var trans *http.Transport var regIdleWatcher func() E.NestedError var unregIdleWatcher func() if entry.NoTLSVerify { - trans = transportNoTLS.Clone() + trans = common.DefaultTransportNoTLS.Clone() } else { - trans = transport.Clone() + trans = common.DefaultTransport.Clone() } - rp := P.NewReverseProxy(entry.URL, trans, entry) + rp := NewReverseProxy(entry.URL, trans) + + if len(entry.Middlewares) > 0 { + err := middleware.PatchReverseProxy(rp, entry.Middlewares) + if err != nil { + return nil, err + } + } if entry.UseIdleWatcher() { // allow time for response header up to `WakeTimeout` @@ -74,7 +97,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { _, exists := httpRoutes.Load(entry.Alias) if exists { - return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias) + return nil, E.Duplicated("HTTPRoute alias", entry.Alias) } r := &HTTPRoute{ @@ -94,11 +117,16 @@ func (r *HTTPRoute) String() string { } func (r *HTTPRoute) Start() E.NestedError { + if r.mux != nil { + return nil + } + httpRoutesMu.Lock() defer httpRoutesMu.Unlock() if r.regIdleWatcher != nil { if err := r.regIdleWatcher(); err.HasError() { + r.unregIdleWatcher = nil return err } } @@ -113,6 +141,10 @@ func (r *HTTPRoute) Start() E.NestedError { } func (r *HTTPRoute) Stop() E.NestedError { + if r.mux == nil { + return nil + } + httpRoutesMu.Lock() defer httpRoutesMu.Unlock() @@ -135,7 +167,7 @@ func (u *URL) MarshalText() (text []byte, err error) { } func ProxyHandler(w http.ResponseWriter, r *http.Request) { - mux, err := findMux(r.Host) + mux, err := findMuxFunc(r.Host) if err != nil { err = E.Failure("request"). Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path). @@ -147,7 +179,7 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { mux.ServeHTTP(w, r) } -func findMux(host string) (*http.ServeMux, E.NestedError) { +func findMuxAnyDomain(host string) (*http.ServeMux, E.NestedError) { hostSplit := strings.Split(host, ".") n := len(hostSplit) if n <= 2 { @@ -160,23 +192,21 @@ func findMux(host string) (*http.ServeMux, E.NestedError) { return nil, E.NotExist("route", sd) } -var ( - defaultDialer = net.Dialer{ - Timeout: 60 * time.Second, - KeepAlive: 60 * time.Second, +func findMuxByDomain(domains []string) func(host string) (*http.ServeMux, E.NestedError) { + return func(host string) (*http.ServeMux, E.NestedError) { + var subdomain string + for _, domain := range domains { + subdomain = strings.TrimSuffix(subdomain, domain) + if subdomain != domain { + break + } + } + if subdomain == "" { // not matched + return nil, E.Invalid("host", host) + } + if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok { + return r.mux, nil + } + return nil, E.NotExist("route", subdomain) } - transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: defaultDialer.DialContext, - MaxIdleConnsPerHost: 1000, - } - transportNoTLS = func() *http.Transport { - var clone = transport.Clone() - clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - return clone - }() - - httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]() - httpRoutesMu sync.Mutex - globalMux = http.NewServeMux() -) +} diff --git a/src/route/middleware/add_x_forwarded.go b/src/route/middleware/add_x_forwarded.go deleted file mode 100644 index bc8a25d6..00000000 --- a/src/route/middleware/add_x_forwarded.go +++ /dev/null @@ -1,7 +0,0 @@ -package middleware - -var AddXForwarded = &Middleware{ - rewrite: func(r *ProxyRequest) { - r.SetXForwarded() - }, -} diff --git a/src/route/middleware/forward_auth.go b/src/route/middleware/forward_auth.go new file mode 100644 index 00000000..0276da08 --- /dev/null +++ b/src/route/middleware/forward_auth.go @@ -0,0 +1,249 @@ +// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go) +// Copyright (c) 2020-2024 Traefik Labs +// Copyright (c) 2024 yusing + +package middleware + +import ( + "io" + "net" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/common" + D "github.com/yusing/go-proxy/docker" + E "github.com/yusing/go-proxy/error" + gpHTTP "github.com/yusing/go-proxy/http" + U "github.com/yusing/go-proxy/utils" +) + +type ( + forwardAuth struct { + *forwardAuthOpts + m *Middleware + client http.Client + } + forwardAuthOpts struct { + Address string + TrustForwardHeader bool + AuthResponseHeaders []string + AddAuthCookiesToResponse []string + } +) + +const ( + xForwardedFor = "X-Forwarded-For" + xForwardedMethod = "X-Forwarded-Method" + xForwardedHost = "X-Forwarded-Host" + xForwardedProto = "X-Forwarded-Proto" + xForwardedURI = "X-Forwarded-Uri" + xForwardedPort = "X-Forwarded-Port" +) + +var ForwardAuth = newForwardAuth() +var faLogger = logrus.WithField("middleware", "ForwardAuth") + +func newForwardAuth() (fa *forwardAuth) { + fa = new(forwardAuth) + fa.m = new(Middleware) + fa.m.labelParserMap = D.ValueParserMap{ + "trust_forward_header": D.BoolParser, + "auth_response_headers": D.YamlStringListParser, + "add_auth_cookies_to_response": D.YamlStringListParser, + } + fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { + tr, ok := rp.Transport.(*http.Transport) + if ok { + tr = tr.Clone() + } else { + tr = common.DefaultTransport.Clone() + } + + faWithOpts := new(forwardAuth) + faWithOpts.forwardAuthOpts = new(forwardAuthOpts) + faWithOpts.client = http.Client{ + CheckRedirect: func(r *Request, via []*Request) error { + return http.ErrUseLastResponse + }, + Timeout: 30 * time.Second, + Transport: tr, + } + faWithOpts.m = &Middleware{ + impl: faWithOpts, + before: fa.forward, + } + + err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts) + if err != nil { + return nil, E.FailWith("set options", err) + } + _, err = E.Check(url.Parse(faWithOpts.Address)) + if err != nil { + return nil, E.Invalid("address", faWithOpts.Address) + } + return faWithOpts.m, nil + } + return +} + +func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) { + removeHop(req.Header) + + faReq, err := http.NewRequestWithContext( + req.Context(), + http.MethodGet, + fa.Address, + nil, + ) + if err != nil { + faLogger.Debugf("new request err to %s: %s", fa.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + copyHeader(faReq.Header, req.Header) + removeHop(faReq.Header) + + filterHeaders(faReq.Header, fa.AuthResponseHeaders) + fa.setAuthHeaders(req, faReq) + + faResp, err := fa.client.Do(faReq) + if err != nil { + faLogger.Debugf("failed to call %s: %s", fa.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer faResp.Body.Close() + + body, err := io.ReadAll(faResp.Body) + if err != nil { + faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { + copyHeader(w.Header(), faResp.Header) + removeHop(w.Header()) + + redirectURL, err := faResp.Location() + if err != nil { + faLogger.Debugf("failed to get location from %s: %s", fa.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } else if redirectURL.String() != "" { + w.Header().Set("Location", redirectURL.String()) + } + + w.WriteHeader(faResp.StatusCode) + + if _, err = w.Write(body); err != nil { + faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err) + } + return + } + + for _, key := range fa.AuthResponseHeaders { + key := http.CanonicalHeaderKey(key) + req.Header.Del(key) + if len(faResp.Header[key]) > 0 { + req.Header[key] = append([]string(nil), faResp.Header[key]...) + } + } + + req.RequestURI = req.URL.RequestURI() + + authCookies := faResp.Cookies() + + if len(authCookies) == 0 { + next.ServeHTTP(w, req) + return + } + + next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error { + fa.setAuthCookies(resp, authCookies) + return nil + }), req) +} + +func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) { + if len(fa.AddAuthCookiesToResponse) == 0 { + return + } + + cookies := resp.Cookies() + resp.Header.Del("Set-Cookie") + + for _, cookie := range cookies { + if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { + // this cookie is not an auth cookie, so add it back + resp.Header.Add("Set-Cookie", cookie.String()) + } + } + + for _, cookie := range authCookies { + if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) { + // this cookie is an auth cookie, so add to resp + resp.Header.Add("Set-Cookie", cookie.String()) + } + } +} + +func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) { + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + if fa.TrustForwardHeader { + if prior, ok := req.Header[xForwardedFor]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + } + faReq.Header.Set(xForwardedFor, clientIP) + } + + xMethod := req.Header.Get(xForwardedMethod) + switch { + case xMethod != "" && fa.TrustForwardHeader: + faReq.Header.Set(xForwardedMethod, xMethod) + case req.Method != "": + faReq.Header.Set(xForwardedMethod, req.Method) + default: + faReq.Header.Del(xForwardedMethod) + } + + xfp := req.Header.Get(xForwardedProto) + switch { + case xfp != "" && fa.TrustForwardHeader: + faReq.Header.Set(xForwardedProto, xfp) + case req.TLS != nil: + faReq.Header.Set(xForwardedProto, "https") + default: + faReq.Header.Set(xForwardedProto, "http") + } + + if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader { + faReq.Header.Set(xForwardedPort, xfp) + } + + xfh := req.Header.Get(xForwardedHost) + switch { + case xfh != "" && fa.TrustForwardHeader: + faReq.Header.Set(xForwardedHost, xfh) + case req.Host != "": + faReq.Header.Set(xForwardedHost, req.Host) + default: + faReq.Header.Del(xForwardedHost) + } + + xfURI := req.Header.Get(xForwardedURI) + switch { + case xfURI != "" && fa.TrustForwardHeader: + faReq.Header.Set(xForwardedURI, xfURI) + case req.URL.RequestURI() != "": + faReq.Header.Set(xForwardedURI, req.URL.RequestURI()) + default: + faReq.Header.Del(xForwardedURI) + } +} diff --git a/src/route/middleware/headers.go b/src/route/middleware/headers.go new file mode 100644 index 00000000..b12b134a --- /dev/null +++ b/src/route/middleware/headers.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "net/http" + "slices" + + gpHTTP "github.com/yusing/go-proxy/http" +) + +func removeHop(h Header) { + reqUpType := gpHTTP.UpgradeType(h) + gpHTTP.RemoveHopByHopHeaders(h) + + if reqUpType != "" { + h.Set("Connection", "Upgrade") + h.Set("Upgrade", reqUpType) + } else { + h.Del("Connection") + } +} + +func copyHeader(dst, src Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func filterHeaders(h Header, allowed []string) { + if allowed == nil { + return + } + + for i := range allowed { + allowed[i] = http.CanonicalHeaderKey(allowed[i]) + } + + for key := range h { + if !slices.Contains(allowed, key) { + h.Del(key) + } + } +} diff --git a/src/route/middleware/middleware.go b/src/route/middleware/middleware.go index 5aedd922..fae2d0c6 100644 --- a/src/route/middleware/middleware.go +++ b/src/route/middleware/middleware.go @@ -3,33 +3,42 @@ package middleware import ( "net/http" + D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" - P "github.com/yusing/go-proxy/proxy" + gpHTTP "github.com/yusing/go-proxy/http" ) type ( - ReverseProxy = P.ReverseProxy - ProxyRequest = P.ProxyRequest + Error = E.NestedError + + ReverseProxy = gpHTTP.ReverseProxy + ProxyRequest = gpHTTP.ProxyRequest Request = http.Request Response = http.Response ResponseWriter = http.ResponseWriter + Header = http.Header + Cookie = http.Cookie - BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool) + BeforeFunc func(next http.Handler, w ResponseWriter, r *Request) RewriteFunc func(req *ProxyRequest) ModifyResponseFunc func(res *Response) error + CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) - MiddlewareOptionsRaw map[string]string - MiddlewareOptions map[string]interface{} + OptionsRaw = map[string]any + Options any Middleware struct { name string - before BeforeFunc - rewrite RewriteFunc - modifyResponse ModifyResponseFunc + before BeforeFunc // runs before ReverseProxy.ServeHTTP + rewrite RewriteFunc // runs after ReverseProxy.Rewrite + modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse - options MiddlewareOptions - validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError) + transport http.RoundTripper + + withOptions CloneWithOptFunc + labelParserMap D.ValueParserMap + impl any } ) @@ -41,41 +50,32 @@ func (m *Middleware) String() string { return m.name } -func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) { - if len(optsRaw) == 0 { - return m, nil - } - - var opts MiddlewareOptions - var err E.NestedError - - if m.validateOptions != nil { - if opts, err = m.validateOptions(optsRaw); err != nil { +func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { + if len(optsRaw) != 0 && m.withOptions != nil { + if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil { return nil, err + } else { + return mWithOpt, nil } } - return &Middleware{ - name: m.name, - before: m.before, - rewrite: m.rewrite, - modifyResponse: m.modifyResponse, - options: opts, - }, nil + // WithOptionsClone is called only once + // set withOptions and labelParser will not be used after that + return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil } -// TODO: check conflict -func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) { - out = rp - +// TODO: check conflict or duplicates +func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) { befores := make([]BeforeFunc, 0, len(middlewares)) rewrites := make([]RewriteFunc, 0, len(middlewares)) modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares)) invalidM := E.NewBuilder("invalid middlewares") invalidOpts := E.NewBuilder("invalid options") - defer invalidM.Add(invalidOpts.Build()) - defer invalidM.To(&err) + defer func() { + invalidM.Add(invalidOpts.Build()) + invalidM.To(&res) + }() for name, opts := range middlewares { m, ok := Get(name) @@ -83,7 +83,8 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions invalidM.Addf("%s", name) continue } - m, err = m.WithOptions(opts) + + m, err := m.WithOptionsClone(opts, rp) if err != nil { invalidOpts.Add(err.Subject(name)) continue @@ -103,25 +104,37 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions return } - if len(befores) > 0 { - rp.ServeHTTP = func(w ResponseWriter, r *Request) { - for _, before := range befores { - if !before(w, r) { - return - } + origServeHTTP := rp.ServeHTTP + for i, before := range befores { + if i < len(befores)-1 { + rp.ServeHTTP = func(w ResponseWriter, r *Request) { + before(rp.ServeHTTP, w, r) + } + } else { + rp.ServeHTTP = func(w ResponseWriter, r *Request) { + before(origServeHTTP, w, r) } - rp.ServeHTTP(w, r) } } + if len(rewrites) > 0 { + origRewrite := rp.Rewrite rp.Rewrite = func(req *ProxyRequest) { + if origRewrite != nil { + origRewrite(req) + } for _, rewrite := range rewrites { rewrite(req) } } } + if len(modifyResponses) > 0 { + origModifyResponse := rp.ModifyResponse rp.ModifyResponse = func(res *Response) error { + if origModifyResponse != nil { + return origModifyResponse(res) + } for _, modifyResponse := range modifyResponses { if err := modifyResponse(res); err != nil { return err diff --git a/src/route/middleware/middlewares.go b/src/route/middleware/middlewares.go index 3c32ffd7..ab21daa5 100644 --- a/src/route/middleware/middlewares.go +++ b/src/route/middleware/middlewares.go @@ -3,14 +3,11 @@ package middleware import ( "fmt" "strings" + + D "github.com/yusing/go-proxy/docker" ) -var middlewares = map[string]*Middleware{ - "set_x_forwarded": SetXForwarded, // nginx - "add_x_forwarded": AddXForwarded, // nginx - "trust_forward_header": AddXForwarded, // traefik alias - "redirect_http": RedirectHTTP, -} +var middlewares map[string]*Middleware func Get(name string) (middleware *Middleware, ok bool) { middleware, ok = middlewares[name] @@ -18,10 +15,23 @@ func Get(name string) (middleware *Middleware, ok bool) { } // initialize middleware names -var _ = func() (_ bool) { +func init() { + middlewares = map[string]*Middleware{ + "set_x_forwarded": SetXForwarded, + "add_x_forwarded": AddXForwarded, + "redirect_http": RedirectHTTP, + "forward_auth": ForwardAuth.m, + "modify_response": ModifyResponse.m, + "modify_request": ModifyRequest.m, + } names := make(map[*Middleware][]string) for name, m := range middlewares { names[m] = append(names[m], name) + // register middleware name to docker label parsr + // in order to parse middleware_name.option=value into correct type + if m.labelParserMap != nil { + D.RegisterNamespace(name, m.labelParserMap) + } } for m, names := range names { if len(names) > 1 { @@ -30,5 +40,4 @@ var _ = func() (_ bool) { m.name = names[0] } } - return -}() +} diff --git a/src/route/middleware/modify_request.go b/src/route/middleware/modify_request.go new file mode 100644 index 00000000..28994959 --- /dev/null +++ b/src/route/middleware/modify_request.go @@ -0,0 +1,58 @@ +package middleware + +import ( + D "github.com/yusing/go-proxy/docker" + E "github.com/yusing/go-proxy/error" + U "github.com/yusing/go-proxy/utils" +) + +type ( + modifyRequest struct { + *modifyRequestOpts + m *Middleware + } + // order: set_headers -> add_headers -> hide_headers + modifyRequestOpts struct { + SetHeaders map[string]string + AddHeaders map[string]string + HideHeaders []string + } +) + +var ModifyRequest = newModifyRequest() + +func newModifyRequest() (mr *modifyRequest) { + mr = new(modifyRequest) + mr.m = new(Middleware) + mr.m.labelParserMap = D.ValueParserMap{ + "set_headers": D.YamlLikeMappingParser(true), + "add_headers": D.YamlLikeMappingParser(true), + "hide_headers": D.YamlStringListParser, + } + mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { + mrWithOpts := new(modifyRequest) + mrWithOpts.m = &Middleware{ + impl: mrWithOpts, + rewrite: mrWithOpts.modifyRequest, + } + mrWithOpts.modifyRequestOpts = new(modifyRequestOpts) + err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts) + if err != nil { + return nil, E.FailWith("set options", err) + } + return mrWithOpts.m, nil + } + return +} + +func (mr *modifyRequest) modifyRequest(req *ProxyRequest) { + for k, v := range mr.SetHeaders { + req.Out.Header.Set(k, v) + } + for k, v := range mr.AddHeaders { + req.Out.Header.Add(k, v) + } + for _, k := range mr.HideHeaders { + req.Out.Header.Del(k) + } +} diff --git a/src/route/middleware/modify_request_test.go b/src/route/middleware/modify_request_test.go new file mode 100644 index 00000000..c925464f --- /dev/null +++ b/src/route/middleware/modify_request_test.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "slices" + "testing" + + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestSetModifyRequest(t *testing.T) { + opts := OptionsRaw{ + "set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"}, + "add_headers": map[string]string{"Accept-Encoding": "test-value"}, + "hide_headers": []string{"Accept"}, + } + + t.Run("set_options", func(t *testing.T) { + mr, err := ModifyRequest.m.WithOptionsClone(opts, nil) + ExpectNoError(t, err.Error()) + ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) + ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) + ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) + }) + + t.Run("request_headers", func(t *testing.T) { + result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ + middlewareOpt: opts, + }) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") + ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) + ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") + }) +} diff --git a/src/route/middleware/modify_response.go b/src/route/middleware/modify_response.go new file mode 100644 index 00000000..1509b82a --- /dev/null +++ b/src/route/middleware/modify_response.go @@ -0,0 +1,61 @@ +package middleware + +import ( + "net/http" + + D "github.com/yusing/go-proxy/docker" + E "github.com/yusing/go-proxy/error" + U "github.com/yusing/go-proxy/utils" +) + +type ( + modifyResponse struct { + *modifyResponseOpts + m *Middleware + } + // order: set_headers -> add_headers -> hide_headers + modifyResponseOpts struct { + SetHeaders map[string]string + AddHeaders map[string]string + HideHeaders []string + } +) + +var ModifyResponse = newModifyResponse() + +func newModifyResponse() (mr *modifyResponse) { + mr = new(modifyResponse) + mr.m = new(Middleware) + mr.m.labelParserMap = D.ValueParserMap{ + "set_headers": D.YamlLikeMappingParser(true), + "add_headers": D.YamlLikeMappingParser(true), + "hide_headers": D.YamlStringListParser, + } + mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { + mrWithOpts := new(modifyResponse) + mrWithOpts.m = &Middleware{ + impl: mrWithOpts, + modifyResponse: mrWithOpts.modifyResponse, + } + mrWithOpts.modifyResponseOpts = new(modifyResponseOpts) + err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts) + if err != nil { + return nil, E.FailWith("set options", err) + } + return mrWithOpts.m, nil + } + return +} + +func (mr *modifyResponse) modifyResponse(resp *http.Response) error { + for k, v := range mr.SetHeaders { + resp.Header.Set(k, v) + } + for k, v := range mr.AddHeaders { + resp.Header.Add(k, v) + } + for _, k := range mr.HideHeaders { + resp.Header.Del(k) + } + return nil +} diff --git a/src/route/middleware/modify_response_test.go b/src/route/middleware/modify_response_test.go new file mode 100644 index 00000000..c63a1eb8 --- /dev/null +++ b/src/route/middleware/modify_response_test.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "slices" + "testing" + + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestSetModifyResponse(t *testing.T) { + opts := OptionsRaw{ + "set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"}, + "add_headers": map[string]string{"Accept-Encoding": "test-value"}, + "hide_headers": []string{"Accept"}, + } + + t.Run("set_options", func(t *testing.T) { + mr, err := ModifyResponse.m.WithOptionsClone(opts, nil) + ExpectNoError(t, err.Error()) + ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) + ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) + ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) + }) + + t.Run("request_headers", func(t *testing.T) { + result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{ + middlewareOpt: opts, + }) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0") + t.Log(result.ResponseHeaders.Get("Accept-Encoding")) + ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) + ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "") + }) +} diff --git a/src/route/middleware/redirect_http.go b/src/route/middleware/redirect_http.go index 613a2b0a..e50685ed 100644 --- a/src/route/middleware/redirect_http.go +++ b/src/route/middleware/redirect_http.go @@ -7,14 +7,13 @@ import ( ) var RedirectHTTP = &Middleware{ - before: func(w ResponseWriter, r *Request) (continue_ bool) { + before: func(next http.Handler, w ResponseWriter, r *Request) { if r.TLS == nil { r.URL.Scheme = "https" - r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort + r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) - } else { - continue_ = true + return } - return + next.ServeHTTP(w, r) }, } diff --git a/src/route/middleware/redirect_http_test.go b/src/route/middleware/redirect_http_test.go new file mode 100644 index 00000000..18b77b1b --- /dev/null +++ b/src/route/middleware/redirect_http_test.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/yusing/go-proxy/common" + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestRedirectToHTTPs(t *testing.T) { + result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ + scheme: "http", + }) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) + ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) +} + +func TestNoRedirect(t *testing.T) { + result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ + scheme: "https", + }) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.ResponseStatus, http.StatusOK) +} diff --git a/src/route/middleware/set_x_forwarded.go b/src/route/middleware/set_x_forwarded.go deleted file mode 100644 index 2cec4aa2..00000000 --- a/src/route/middleware/set_x_forwarded.go +++ /dev/null @@ -1,10 +0,0 @@ -package middleware - -var SetXForwarded = &Middleware{ - rewrite: func(r *ProxyRequest) { - r.Out.Header.Del("X-Forwarded-For") - r.Out.Header.Del("X-Forwarded-Host") - r.Out.Header.Del("X-Forwarded-Proto") - r.SetXForwarded() - }, -} diff --git a/src/route/middleware/test_data/sample_headers.json b/src/route/middleware/test_data/sample_headers.json new file mode 100644 index 00000000..e1276484 --- /dev/null +++ b/src/route/middleware/test_data/sample_headers.json @@ -0,0 +1,17 @@ +{ + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "Accept-Encoding": "gzip, deflate, br, zstd", + "Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6", + "Dnt": "1", + "Host": "localhost", + "Priority": "u=0, i", + "Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"", + "Sec-Ch-Ua-Mobile": "?0", + "Sec-Ch-Ua-Platform": "\"Windows\"", + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "none", + "Sec-Fetch-User": "?1", + "Upgrade-Insecure-Requests": "1", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36" +} \ No newline at end of file diff --git a/src/route/middleware/test_utils.go b/src/route/middleware/test_utils.go new file mode 100644 index 00000000..f19dbd25 --- /dev/null +++ b/src/route/middleware/test_utils.go @@ -0,0 +1,125 @@ +package middleware + +import ( + "bytes" + _ "embed" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + + E "github.com/yusing/go-proxy/error" + gpHTTP "github.com/yusing/go-proxy/http" +) + +//go:embed test_data/sample_headers.json +var testHeadersRaw []byte +var testHeaders http.Header + +const testHost = "example.com" + +func init() { + tmp := map[string]string{} + err := json.Unmarshal(testHeadersRaw, &tmp) + if err != nil { + panic(err) + } + testHeaders = http.Header{} + for k, v := range tmp { + testHeaders.Set(k, v) + } +} + +type requestHeaderRecorder struct { + parent http.RoundTripper + reqHeaders http.Header +} + +func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) { + rt.reqHeaders = req.Header + if rt.parent != nil { + return rt.parent.RoundTrip(req) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: testHeaders, + Body: io.NopCloser(bytes.NewBufferString("OK")), + Request: req, + }, nil +} + +type TestResult struct { + RequestHeaders http.Header + ResponseHeaders http.Header + ResponseStatus int + Data []byte +} + +type testArgs struct { + middlewareOpt OptionsRaw + proxyURL string + body []byte + scheme string +} + +func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) { + var body io.Reader + var rt = new(requestHeaderRecorder) + var proxyURL *url.URL + var requestTarget string + var err error + + if args == nil { + args = new(testArgs) + } + + if args.body != nil { + body = bytes.NewReader(args.body) + } + + if args.scheme == "" || args.scheme == "http" { + requestTarget = "http://" + testHost + } else if args.scheme == "https" { + requestTarget = "https://" + testHost + } else { + panic("typo?") + } + + req := httptest.NewRequest(http.MethodGet, requestTarget, body) + w := httptest.NewRecorder() + + if args.scheme == "https" && req.TLS == nil { + panic("bug occurred") + } + + if args.proxyURL != "" { + proxyURL, err = url.Parse(args.proxyURL) + if err != nil { + return nil, E.From(err) + } + rt.parent = http.DefaultTransport + } else { + proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect + } + rp := gpHTTP.NewReverseProxy(proxyURL, rt) + setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{ + middleware.name: args.middlewareOpt, + }) + if setOptErr != nil { + return nil, setOptErr + } + rp.ServeHTTP(w, req) + resp := w.Result() + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, E.From(err) + } + return &TestResult{ + RequestHeaders: rt.reqHeaders, + ResponseHeaders: resp.Header, + ResponseStatus: resp.StatusCode, + Data: data, + }, nil +} diff --git a/src/route/middleware/x_forwarded.go b/src/route/middleware/x_forwarded.go new file mode 100644 index 00000000..902fb8e4 --- /dev/null +++ b/src/route/middleware/x_forwarded.go @@ -0,0 +1,9 @@ +package middleware + +var AddXForwarded = &Middleware{ + rewrite: (*ProxyRequest).AddXForwarded, +} + +var SetXForwarded = &Middleware{ + rewrite: (*ProxyRequest).SetXForwarded, +} diff --git a/src/server/server.go b/src/server/server.go index d054fe54..57276dd1 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -84,6 +84,11 @@ func NewServer(opt Options) (s *server) { } } +// Start will start the http and https servers. +// +// If both are not set, this does nothing. +// +// Start() is non-blocking func (s *server) Start() { if s.http == nil && s.https == nil { return diff --git a/src/utils/serialization.go b/src/utils/serialization.go index 28582d93..c4cd3108 100644 --- a/src/utils/serialization.go +++ b/src/utils/serialization.go @@ -106,7 +106,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { return result, nil } -func Deserialize(src map[string]any, target any) E.NestedError { +func Deserialize(src SerializedObject, target any) E.NestedError { // convert data fields to lower no-snake // convert target fields to lower // then check if the field of data is in the target @@ -117,6 +117,10 @@ func Deserialize(src map[string]any, target any) E.NestedError { snakeCaseField := strings.ToLower(field.Name) mapping[snakeCaseField] = field.Name } + tValue := reflect.ValueOf(target) + if tValue.IsZero() { + return E.Invalid("value", "nil") + } for k, v := range src { kCleaned := toLowerNoSnake(k) if fieldName, ok := mapping[kCleaned]; ok { @@ -150,13 +154,13 @@ func Deserialize(src map[string]any, target any) E.NestedError { } prop.Set(propNew) default: - return E.Unsupported("field", k).Extraf("type=%s", propType) + return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType) } } else { - return E.Unsupported("field", k).Extraf("type=%s", propType) + return E.Unsupported("field", k).Extraf("type %s is not settable", propType) } } else { - return E.Failure("unknown field").With(k) + return E.Unexpected("field", k) } } diff --git a/src/utils/testing/testing.go b/src/utils/testing/testing.go index 23f8242b..fb6e861b 100644 --- a/src/utils/testing/testing.go +++ b/src/utils/testing/testing.go @@ -38,6 +38,17 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) { } } +func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) { + t.Helper() + for _, want := range wants { + if got == want { + return + } + } + t.Errorf("expected any of:\n%v, got\n%v", wants, got) + t.FailNow() +} + func ExpectDeepEqual[T any](t *testing.T, got T, want T) { t.Helper() if !reflect.DeepEqual(got, want) {