diff --git a/config.example.yml b/config.example.yml index 4de8a686..a1eec288 100644 --- a/config.example.yml +++ b/config.example.yml @@ -52,8 +52,8 @@ entrypoint: # Note that HTTP/3 with proxy protocol is not supported yet. support_proxy_protocol: false - # To relay the client address to a TCP upstream, enable `relay_proxy_protocol_header: true` - # on that specific TCP route. UDP relay is not supported yet. + # To relay the client address to a TCP upstream (UDP relay is not supported yet) + relay_proxy_protocol_header: false # Below define an example of middleware config # 1. set security headers diff --git a/internal/config/state.go b/internal/config/state.go index 1b0bbbb2..7805869e 100644 --- a/internal/config/state.go +++ b/internal/config/state.go @@ -9,6 +9,8 @@ import ( "fmt" "io/fs" "iter" + "net" + "net/netip" "os" "strconv" "strings" diff --git a/internal/docker/label.go b/internal/docker/label.go index 73df326c..6f930181 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -1,8 +1,11 @@ package docker import ( + "cmp" "errors" "fmt" + "maps" + "slices" "strconv" "strings" @@ -15,13 +18,19 @@ var ErrInvalidLabel = errors.New("invalid label") const nsProxyDot = NSProxy + "." -var refPrefixes = func() []string { - prefixes := make([]string, 100) - for i := range prefixes { - prefixes[i] = nsProxyDot + "#" + strconv.Itoa(i+1) + "." +type UnexpectedTypeError struct { + Expected string + Actual any + // Message, if non-empty, is returned by Error() instead of the default "expect …, got …" form. + Message string +} + +func (e UnexpectedTypeError) Error() string { + if e.Message != "" { + return e.Message } - return prefixes -}() + return fmt.Sprintf("expect %s, got %T", e.Expected, e.Actual) +} func ParseLabels(labels map[string]string, aliases ...string) (types.LabelMap, error) { nestedMap := make(types.LabelMap) @@ -29,44 +38,125 @@ func ParseLabels(labels map[string]string, aliases ...string) (types.LabelMap, e ExpandWildcard(labels, aliases...) - for lbl, value := range labels { - parts := strings.Split(lbl, ".") - if parts[0] != NSProxy { - continue - } - if len(parts) == 1 { - errs.AddSubject(ErrInvalidLabel, lbl) - continue - } - parts = parts[1:] - currentMap := nestedMap - - for i, k := range parts { - if i == len(parts)-1 { - // Last element, set the value - currentMap[k] = value - } else { - // If the key doesn't exist, create a new map - if _, exists := currentMap[k]; !exists { - currentMap[k] = make(types.LabelMap) - } - // Move deeper into the nested map - m, ok := currentMap[k].(types.LabelMap) - if !ok && currentMap[k] != "" { - errs.AddSubject(fmt.Errorf("expect mapping, got %T", currentMap[k]), lbl) - continue - } else if !ok { - m = make(types.LabelMap) - currentMap[k] = m - } - currentMap = m - } + keys := slices.SortedFunc(maps.Keys(labels), compareLabelKeys) + for _, lbl := range keys { + if err := applyLabel(nestedMap, lbl, labels[lbl]); err != nil { + errs.AddSubject(err, lbl) } } return nestedMap, errs.Error() } +func applyLabel(dst types.LabelMap, lbl, value string) error { + parts := strings.Split(lbl, ".") + if parts[0] != NSProxy { + return nil + } + if len(parts) == 1 { + return ErrInvalidLabel + } + + currentMap := dst + for _, part := range parts[1 : len(parts)-1] { + nextMap, err := descendLabelMap(currentMap, part) + if err != nil { + return err + } + currentMap = nextMap + } + + return setLabelValue(currentMap, parts[len(parts)-1], value) +} + +func descendLabelMap(currentMap types.LabelMap, key string) (types.LabelMap, error) { + if next, ok := currentMap[key]; ok { + switch typed := next.(type) { + case types.LabelMap: + return typed, nil + case string: + objectValue, isObject := parseLabelObject(typed) + if !isObject { + return nil, UnexpectedTypeError{Expected: "mapping", Actual: next} + } + currentMap[key] = objectValue + return objectValue, nil + default: + return nil, UnexpectedTypeError{Expected: "mapping", Actual: next} + } + } + + nextMap := make(types.LabelMap) + currentMap[key] = nextMap + return nextMap, nil +} + +func setLabelValue(currentMap types.LabelMap, key, value string) error { + existing, ok := currentMap[key].(types.LabelMap) + if !ok { + currentMap[key] = value + return nil + } + + objectValue, isObject := parseLabelObject(value) + if !isObject { + return UnexpectedTypeError{Expected: "mapping", Actual: value} + } + return mergeLabelMaps(existing, objectValue) +} + +func parseLabelObject(value string) (types.LabelMap, bool) { + if value == "" { + return make(types.LabelMap), true + } + + objectValue := make(types.LabelMap) + if err := yaml.Unmarshal([]byte(strings.ReplaceAll(value, "\t", " ")), &objectValue); err != nil { + return nil, false + } + return objectValue, true +} + +func mergeLabelMaps(dst, src types.LabelMap) error { + for key, srcValue := range src { + existingValue, exists := dst[key] + if !exists { + dst[key] = srcValue + continue + } + + existingMap, existingIsMap := existingValue.(types.LabelMap) + srcMap, srcIsMap := srcValue.(types.LabelMap) + if existingIsMap && srcIsMap { + if err := mergeLabelMaps(existingMap, srcMap); err != nil { + return err + } + continue + } + if existingIsMap { + return UnexpectedTypeError{Expected: "mapping", Actual: srcValue} + } + if srcIsMap { + return UnexpectedTypeError{ + Expected: "scalar", + Actual: srcValue, + Message: fmt.Sprintf( + "cannot merge mapping into existing scalar; merge source is %T", + srcValue, + ), + } + } + } + return nil +} + +func compareLabelKeys(a, b string) int { + if parts := cmp.Compare(strings.Count(a, "."), strings.Count(b, ".")); parts != 0 { + return parts + } + return cmp.Compare(a, b) +} + func ExpandWildcard(labels map[string]string, aliases ...string) { aliasSet := make(map[string]int, len(aliases)) for i, alias := range aliases { @@ -77,12 +167,10 @@ func ExpandWildcard(labels map[string]string, aliases ...string) { // First pass: collect wildcards and discover aliases for lbl, value := range labels { - if !strings.HasPrefix(lbl, nsProxyDot) { + alias, suffix, ok := splitAliasLabel(lbl) + if !ok { continue } - // lbl is "proxy.X..." where X is alias or wildcard - rest := lbl[len(nsProxyDot):] // "X..." or "X.suffix" - alias, suffix, _ := strings.Cut(rest, ".") if alias == WildcardAlias { delete(labels, lbl) if suffix == "" || strings.Count(value, "\n") > 1 { @@ -108,15 +196,10 @@ func ExpandWildcard(labels map[string]string, aliases ...string) { // Second pass: convert explicit labels to #N format for lbl, value := range labels { - if !strings.HasPrefix(lbl, nsProxyDot) { + alias, suffix, ok := splitAliasLabel(lbl) + if !ok || suffix == "" || alias == "" || alias[0] == '#' { continue } - rest := lbl[len(nsProxyDot):] - alias, suffix, ok := strings.Cut(rest, ".") - if !ok || alias == "" || alias[0] == '#' { - continue - } - idx, known := aliasSet[alias] if !known { continue @@ -124,24 +207,33 @@ func ExpandWildcard(labels map[string]string, aliases ...string) { delete(labels, lbl) if _, overridden := wildcardLabels[suffix]; !overridden { - labels[refPrefixes[idx]+suffix] = value + labels[refPrefix(idx)+suffix] = value } } // Expand wildcards for all aliases for suffix, value := range wildcardLabels { for _, idx := range aliasSet { - labels[refPrefixes[idx]+suffix] = value + labels[refPrefix(idx)+suffix] = value } } } +func splitAliasLabel(lbl string) (alias, suffix string, ok bool) { + rest, ok := strings.CutPrefix(lbl, nsProxyDot) + if !ok { + return "", "", false + } + alias, suffix, _ = strings.Cut(rest, ".") + return alias, suffix, true +} + // expandYamlWildcard parses a YAML document in value, flattens it to dot-notated keys and adds the // results into dest map where each key is the flattened suffix and the value is the scalar string // representation. The provided YAML is expected to be a mapping. func expandYamlWildcard(value string, dest map[string]string) { // replace tab indentation with spaces to make YAML parser happy - yamlStr := strings.ReplaceAll(value, "\t", " ") + yamlStr := strings.ReplaceAll(value, "\t", " ") raw := make(map[string]any) if err := yaml.Unmarshal([]byte(yamlStr), &raw); err != nil { @@ -152,59 +244,53 @@ func expandYamlWildcard(value string, dest map[string]string) { flattenMap("", raw, dest) } +// refPrefix returns the prefix for a reference to the Nth alias. +func refPrefix(n int) string { + return nsProxyDot + "#" + strconv.Itoa(n+1) + "." +} + // flattenMap converts nested maps into a flat map with dot-delimited keys. func flattenMap(prefix string, src map[string]any, dest map[string]string) { for k, v := range src { - key := k - if prefix != "" { - key = prefix + "." + k - } - switch vv := v.(type) { - case map[string]any: - flattenMap(key, vv, dest) - case map[any]any: - flattenMapAny(key, vv, dest) - case string: - dest[key] = vv - case int: - dest[key] = strconv.Itoa(vv) - case bool: - dest[key] = strconv.FormatBool(vv) - case float64: - dest[key] = strconv.FormatFloat(vv, 'f', -1, 64) - default: - dest[key] = fmt.Sprint(v) - } + flattenValue(joinLabelKey(prefix, k), v, dest) } } func flattenMapAny(prefix string, src map[any]any, dest map[string]string) { for k, v := range src { - var key string - switch kk := k.(type) { - case string: - key = kk - default: - key = fmt.Sprint(k) - } - if prefix != "" { - key = prefix + "." + key - } - switch vv := v.(type) { - case map[string]any: - flattenMap(key, vv, dest) - case map[any]any: - flattenMapAny(key, vv, dest) - case string: - dest[key] = vv - case int: - dest[key] = strconv.Itoa(vv) - case bool: - dest[key] = strconv.FormatBool(vv) - case float64: - dest[key] = strconv.FormatFloat(vv, 'f', -1, 64) - default: - dest[key] = fmt.Sprint(v) - } + flattenValue(joinLabelKey(prefix, stringifyLabelKey(k)), v, dest) } } + +func flattenValue(key string, value any, dest map[string]string) { + switch typed := value.(type) { + case map[string]any: + flattenMap(key, typed, dest) + case map[any]any: + flattenMapAny(key, typed, dest) + case string: + dest[key] = typed + case int: + dest[key] = strconv.Itoa(typed) + case bool: + dest[key] = strconv.FormatBool(typed) + case float64: + dest[key] = strconv.FormatFloat(typed, 'f', -1, 64) + default: + dest[key] = fmt.Sprint(value) + } +} + +func joinLabelKey(prefix, key string) string { + if prefix == "" { + return key + } + return prefix + "." + key +} + +func stringifyLabelKey(key any) string { + if typed, ok := key.(string); ok { + return typed + } + return fmt.Sprint(key) +} diff --git a/internal/docker/label_internal_test.go b/internal/docker/label_internal_test.go new file mode 100644 index 00000000..28e5ed08 --- /dev/null +++ b/internal/docker/label_internal_test.go @@ -0,0 +1,310 @@ +package docker + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/types" +) + +func TestParseLabelsIgnoresNonProxyAndRejectsInvalidRoot(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "other.label": "value", + "proxy": "invalid", + }) + + require.ErrorIs(t, err, ErrInvalidLabel) + require.Empty(t, parsed) +} + +func TestParseLabelsPromotesEmptyStringIntoNestedObject(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "", + "proxy.a.b.c": "value", + }) + + require.NoError(t, err) + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": "value", + }, + }, + }, parsed) +} + +func TestParseLabelsMergesObjectIntoExistingMap(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c: generic\nd: merged", + "proxy.a.b.c": "specific", + }) + + require.NoError(t, err) + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": "specific", + "d": "merged", + }, + }, + }, parsed) +} + +func TestParseLabelsRejectsInvalidObjectMergeValue(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "- invalid", + "proxy.a.b.c": "specific", + }) + + require.ErrorContains(t, err, "proxy.a.b.c") + require.ErrorContains(t, err, "expect mapping, got string") + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": "- invalid", + }, + }, parsed) +} + +func TestParseLabelsRejectsSpecificFieldOverrideOfNestedObjectField(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c:\n nested: value", + "proxy.a.b.c": "specific", + }) + + require.ErrorContains(t, err, "proxy.a.b.c") + require.ErrorContains(t, err, "expect mapping, got string") + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": types.LabelMap{ + "nested": "value", + }, + }, + }, + }, parsed) +} + +func TestParseLabelsMergesIntoExistingNestedMap(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c:\n nested:\n allow: true", + "proxy.a.b.c": "nested:\n deny: true", + }) + + require.NoError(t, err) + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": types.LabelMap{ + "nested": types.LabelMap{ + "allow": true, + "deny": true, + }, + }, + }, + }, + }, parsed) +} + +func TestParseLabelsRejectsInvalidNestedObjectMergeValue(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c:\n nested: value", + "proxy.a.b.c": "- invalid", + }) + + require.ErrorContains(t, err, "proxy.a.b.c") + require.ErrorContains(t, err, "expect mapping, got string") + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": types.LabelMap{ + "nested": "value", + }, + }, + }, + }, parsed) +} + +func TestParseLabelsRejectsConflictingNestedObjectMerge(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c:\n nested:\n allow: true", + "proxy.a.b.c": "nested: blocked", + }) + + require.ErrorContains(t, err, "proxy.a.b.c") + require.ErrorContains(t, err, "expect mapping, got string") + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": types.LabelMap{ + "nested": types.LabelMap{ + "allow": true, + }, + }, + }, + }, + }, parsed) +} + +func TestParseLabelsRejectsNestedFieldInsideScalarObjectMember(t *testing.T) { + parsed, err := ParseLabels(map[string]string{ + "proxy.a.b": "c: 1", + "proxy.a.b.c.d": "value", + }) + + require.ErrorContains(t, err, "proxy.a.b.c.d") + require.ErrorContains(t, err, "expect mapping, got uint64") + require.Equal(t, types.LabelMap{ + "a": types.LabelMap{ + "b": types.LabelMap{ + "c": uint64(1), + }, + }, + }, parsed) +} + +func TestParseLabelObject(t *testing.T) { + t.Run("empty string becomes empty map", func(t *testing.T) { + parsed, ok := parseLabelObject("") + require.True(t, ok) + require.Empty(t, parsed) + }) + + t.Run("yaml object parses", func(t *testing.T) { + parsed, ok := parseLabelObject("nested:\n\tvalue: true") + require.True(t, ok) + require.Equal(t, types.LabelMap{ + "nested": types.LabelMap{ + "value": true, + }, + }, parsed) + }) + + t.Run("non-object yaml is rejected", func(t *testing.T) { + parsed, ok := parseLabelObject("- item") + require.False(t, ok) + require.Nil(t, parsed) + }) +} + +func TestMergeLabelMaps(t *testing.T) { + t.Run("recursively merges nested maps and preserves specific scalar overrides", func(t *testing.T) { + dst := types.LabelMap{ + "allowed_groups": []any{"specific"}, + "bypass": types.LabelMap{ + "path": "/private", + }, + } + src := types.LabelMap{ + "allowed_groups": []any{"generic"}, + "bypass": types.LabelMap{ + "methods": "GET", + }, + "priority": 5, + } + + err := mergeLabelMaps(dst, src) + require.NoError(t, err) + require.Equal(t, types.LabelMap{ + "allowed_groups": []any{"specific"}, + "bypass": types.LabelMap{ + "path": "/private", + "methods": "GET", + }, + "priority": 5, + }, dst) + }) + + t.Run("rejects map receiving scalar", func(t *testing.T) { + err := mergeLabelMaps(types.LabelMap{ + "bypass": types.LabelMap{"path": "/private"}, + }, types.LabelMap{ + "bypass": "skip", + }) + + require.ErrorContains(t, err, "expect mapping") + }) + + t.Run("rejects scalar receiving map", func(t *testing.T) { + err := mergeLabelMaps(types.LabelMap{ + "bypass": "skip", + }, types.LabelMap{ + "bypass": types.LabelMap{"path": "/private"}, + }) + + require.ErrorContains(t, err, "cannot merge mapping into existing scalar") + }) + + t.Run("rejects nested recursive map conflicts", func(t *testing.T) { + err := mergeLabelMaps(types.LabelMap{ + "outer": types.LabelMap{ + "nested": types.LabelMap{"allow": true}, + }, + }, types.LabelMap{ + "outer": types.LabelMap{ + "nested": "blocked", + }, + }) + + require.ErrorContains(t, err, "expect mapping") + }) +} + +func TestCompareLabelKeys(t *testing.T) { + require.Less(t, compareLabelKeys("proxy.a", "proxy.a.b"), 0) + require.Less(t, compareLabelKeys("proxy.a.a", "proxy.a.b"), 0) + require.Greater(t, compareLabelKeys("proxy.a.c", "proxy.a.b"), 0) +} + +func TestFlattenMapAny(t *testing.T) { + dest := make(map[string]string) + + flattenMapAny("", map[any]any{ + "nested": map[any]any{ + "string": "value", + "int": 7, + "bool": true, + "float": 1.5, + 9: "numeric-key", + "map": map[string]any{ + "child": "value", + }, + }, + "list": []int{1, 2}, + }, dest) + + require.Equal(t, map[string]string{ + "nested.string": "value", + "nested.int": "7", + "nested.bool": "true", + "nested.float": "1.5", + "nested.9": "numeric-key", + "nested.map.child": "value", + "list": "[1 2]", + }, dest) +} + +func TestFlattenMap(t *testing.T) { + dest := make(map[string]string) + + flattenMap("", map[string]any{ + "nested": map[string]any{ + "string": "value", + "mapany": map[any]any{ + "child": "nested-value", + }, + "int": 7, + "bool": true, + "float": 1.5, + }, + "list": []int{1, 2}, + }, dest) + + require.Equal(t, map[string]string{ + "nested.string": "value", + "nested.mapany.child": "nested-value", + "nested.int": "7", + "nested.bool": "true", + "nested.float": "1.5", + "list": "[1 2]", + }, dest) +} diff --git a/internal/docker/label_test.go b/internal/docker/label_test.go index 6a7a3848..b8aa8580 100644 --- a/internal/docker/label_test.go +++ b/internal/docker/label_test.go @@ -1,6 +1,7 @@ package docker_test import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -242,6 +243,14 @@ port: 8080`[1:] }) } +func requireMap(t *testing.T, value any) map[string]any { + t.Helper() + + m, ok := value.(map[string]any) + require.True(t, ok, "expected map[string]any, got %T", value) + return m +} + func BenchmarkParseLabels(b *testing.B) { m := map[string]string{ "proxy.a.host": "localhost", @@ -253,3 +262,39 @@ func BenchmarkParseLabels(b *testing.B) { _, _ = docker.ParseLabels(m, "a", "b") } } + +func TestParseLabelsMixedObjectAndFlatFields(t *testing.T) { + for i := range 100 { + labels := map[string]string{ + "proxy.universal.middlewares.oidc": "allowed_groups: [everyone]", + "proxy.universal.middlewares.oidc.bypass": "- path glob(/geheimenvan/*)", + } + + parsed, err := docker.ParseLabels(labels) + require.NoError(t, err, fmt.Sprintf("iteration %d", i)) + + universal := requireMap(t, parsed["universal"]) + middlewares := requireMap(t, universal["middlewares"]) + oidc := requireMap(t, middlewares["oidc"]) + + require.Equal(t, []any{"everyone"}, oidc["allowed_groups"]) + require.Equal(t, "- path glob(/geheimenvan/*)", oidc["bypass"]) + } +} + +func TestParseLabelsRejectsScalarAndNestedObjectConflict(t *testing.T) { + for i := range 100 { + parsed, err := docker.ParseLabels(map[string]string{ + "proxy.universal.middlewares.oidc": "bypass: skip", + "proxy.universal.middlewares.oidc.bypass.path": "/geheimenvan", + }) + + require.ErrorContains(t, err, "proxy.universal.middlewares.oidc.bypass.path") + require.ErrorContains(t, err, "expect mapping, got string") + + universal := requireMap(t, parsed["universal"]) + middlewares := requireMap(t, universal["middlewares"]) + oidc := requireMap(t, middlewares["oidc"]) + require.Equal(t, "skip", oidc["bypass"], "iteration %d", i) + } +}