restructured the project to comply community guideline, for others check release note

This commit is contained in:
yusing
2024-09-28 09:51:34 +08:00
parent 4120fd8d1c
commit 90487bfde6
134 changed files with 1110 additions and 625 deletions

142
internal/proxy/entry.go Normal file
View File

@@ -0,0 +1,142 @@
package proxy
import (
"fmt"
"net/url"
"time"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
T "github.com/yusing/go-proxy/internal/proxy/fields"
)
type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
URL *url.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
Middlewares D.NestedLabelMap
/* Docker only */
IdleTimeout time.Duration
WakeTimeout time.Duration
StopMethod T.StopMethod
StopTimeout int
StopSignal T.Signal
DockerHost string
ContainerName string
ContainerRunning bool
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"`
Port T.StreamPort `json:"port"`
}
)
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.DockerHost != ""
}
func ValidateEntry(m *M.RawEntry) (any, E.NestedError) {
if !m.FillMissingFields() {
return nil, E.Missing("fields")
}
scheme, err := T.NewScheme(m.Scheme)
if err.HasError() {
return nil, err
}
var entry any
e := E.NewBuilder("error validating entry")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err.HasError() {
return nil, err
}
return entry, nil
}
func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
return nil
}
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
ContainerRunning: m.Running,
}
}
func validateStreamEntry(m *M.RawEntry, b E.Builder) *StreamEntry {
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := T.ValidateStreamScheme(m.Scheme)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
}
}

View File

@@ -0,0 +1,6 @@
package fields
type (
Alias string
NewAlias = Alias
)

View File

@@ -0,0 +1,19 @@
package fields
import (
"net/http"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h := make(http.Header)
for k, v := range headers {
vSplit := strings.Split(v, ",")
for _, header := range vSplit {
h.Add(k, strings.TrimSpace(header))
}
}
return h, nil
}

View File

@@ -0,0 +1,12 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Host string
type Subdomain = Alias
func ValidateHost[String ~string](s String) (Host, E.NestedError) {
return Host(s), nil
}

View File

@@ -0,0 +1,24 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm {
case "", "forward":
return PathMode(pm), nil
default:
return "", E.Invalid("path mode", pm)
}
}
func (p PathMode) IsRemove() bool {
return p == ""
}
func (p PathMode) IsForward() bool {
return p == "forward"
}

View File

@@ -0,0 +1,37 @@
package fields
import (
"regexp"
E "github.com/yusing/go-proxy/internal/error"
)
type PathPattern string
type PathPatterns = []PathPattern
func NewPathPattern(s string) (PathPattern, E.NestedError) {
if len(s) == 0 {
return "", E.Invalid("path", "must not be empty")
}
if !pathPattern.MatchString(s) {
return "", E.Invalid("path pattern", s)
}
return PathPattern(s), nil
}
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
if len(s) == 0 {
return []PathPattern{"/"}, nil
}
pp := make(PathPatterns, len(s))
for i, v := range s {
if pattern, err := NewPathPattern(v); err.HasError() {
return nil, err
} else {
pp[i] = pattern
}
}
return pp, nil
}
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)

View File

@@ -0,0 +1,47 @@
package fields
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils/testing"
)
var validPatterns = []string{
"/",
"/index.html",
"/somepage/",
"/drive/abc.mp4",
"/{$}",
"/some-page/{$}",
"GET /",
"GET /static/{$}",
"GET /drive/abc.mp4",
"GET /drive/abc.mp4/{$}",
"POST /auth",
"DELETE /user/",
"PUT /storage/id/",
}
var invalidPatterns = []string{
"/$",
"/{$}{$}",
"/{$}/{$}",
"/index.html$",
"get /",
"GET/",
"GET /$",
"GET /drive/{$}/abc.mp4/",
"OPTION /config/{$}/abc.conf/{$}",
}
func TestPathPatternRegex(t *testing.T) {
for _, pattern := range validPatterns {
_, err := NewPathPattern(pattern)
U.ExpectNoError(t, err.Error())
}
for _, pattern := range invalidPatterns {
_, err := NewPathPattern(pattern)
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error())
}
}

View File

@@ -0,0 +1,41 @@
package fields
import (
"strconv"
E "github.com/yusing/go-proxy/internal/error"
)
type Port int
func ValidatePort[String ~string](v String) (Port, E.NestedError) {
p, err := strconv.Atoi(string(v))
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
}
return ValidatePortInt(p)
}
func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) {
p := Port(v)
if !p.inBound() {
return ErrPort, E.OutOfRange("port", p)
}
return p, nil
}
func (p Port) inBound() bool {
return p >= MinPort && p <= MaxPort
}
func (p Port) String() string {
return strconv.Itoa(int(p))
}
const (
MinPort = 0
MaxPort = 65535
ErrPort = Port(-1)
NoPort = Port(-1)
ZeroPort = Port(0)
)

View File

@@ -0,0 +1,21 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.NestedError) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), nil
}
return "", E.Invalid("scheme", s)
}
func (s Scheme) IsHTTP() bool { return s == "http" }
func (s Scheme) IsHTTPS() bool { return s == "https" }
func (s Scheme) IsTCP() bool { return s == "tcp" }
func (s Scheme) IsUDP() bool { return s == "udp" }
func (s Scheme) IsStream() bool { return s.IsTCP() || s.IsUDP() }

View File

@@ -0,0 +1,17 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View File

@@ -0,0 +1,23 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View File

@@ -0,0 +1,56 @@
package fields
import (
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
type StreamPort struct {
ListeningPort Port `json:"listening"`
ProxyPort Port `json:"proxy"`
}
func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
split := strings.Split(p, ":")
switch len(split) {
case 1:
split = []string{"0", split[0]}
case 2:
break
default:
return ErrStreamPort, E.Invalid("stream port", p).With("too many colons")
}
listeningPort, err := ValidatePort(split[0])
if err != nil {
return ErrStreamPort, err.Subject("listening port")
}
proxyPort, err := ValidatePort(split[1])
if err.Is(E.ErrOutOfRange) {
return ErrStreamPort, err.Subject("proxy port")
} else if proxyPort == 0 {
return ErrStreamPort, E.Invalid("proxy port", p)
} else if err != nil {
proxyPort, err = parseNameToPort(split[1])
if err != nil {
return ErrStreamPort, E.Invalid("proxy port", proxyPort)
}
}
return StreamPort{listeningPort, proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.NestedError) {
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return ErrPort, E.Invalid("service", name)
}
return Port(port), nil
}
var ErrStreamPort = StreamPort{ErrPort, ErrPort}

View File

@@ -0,0 +1,48 @@
package fields
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var validPorts = []string{
"1234:5678",
"0:2345",
"2345",
"1234:postgres",
}
var invalidPorts = []string{
"",
"123:",
"0:",
":1234",
"1234:1234:1234",
"qwerty",
"asdfgh:asdfgh",
"1234:asdfgh",
}
var outOfRangePorts = []string{
"-1:1234",
"1234:-1",
"65536",
"0:65536",
}
func TestStreamPort(t *testing.T) {
for _, port := range validPorts {
_, err := ValidateStreamPort(port)
ExpectNoError(t, err.Error())
}
for _, port := range invalidPorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrInvalid, err.Error())
}
for _, port := range outOfRangePorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrOutOfRange, err.Error())
}
}

View File

@@ -0,0 +1,43 @@
package fields
import (
"fmt"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type StreamScheme struct {
ListeningScheme Scheme `json:"listening"`
ProxyScheme Scheme `json:"proxy"`
}
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
ss = &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {
parts = []string{s, s}
} else if len(parts) != 2 {
return nil, E.Invalid("stream scheme", s)
}
ss.ListeningScheme, err = NewScheme(parts[0])
if err.HasError() {
return nil, err
}
ss.ProxyScheme, err = NewScheme(parts[1])
if err.HasError() {
return nil, err
}
return ss, nil
}
func (s StreamScheme) String() string {
return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme)
}
// IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal.
//
// It returns a boolean value indicating whether the ListeningScheme and ProxyScheme are equal.
func (s StreamScheme) IsCoherent() bool {
return s.ListeningScheme == s.ProxyScheme
}

View File

@@ -0,0 +1,18 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

206
internal/proxy/provider/docker.go Executable file
View File

@@ -0,0 +1,206 @@
package provider
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher"
)
type DockerProvider struct {
dockerHost, hostname string
}
var AliasRefRegex = regexp.MustCompile(`#\d+`)
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
hostname, err := D.ParseDockerHostname(dockerHost)
if err.HasError() {
return nil, err
}
return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil
}
func (p *DockerProvider) String() string {
return fmt.Sprintf("docker:%s", p.dockerHost)
}
func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
routes = R.NewRoutes()
entries := M.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() {
return routes, E.FailWith("connect to docker", err)
}
errors := E.NewBuilder("errors when parse docker labels")
for _, c := range info.Containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() {
errors.Add(err)
}
// although err is not nil
// there may be some valid entries in `en`
dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error
dups.RangeAll(func(k string, v *M.RawEntry) {
errors.Addf("duplicate alias %s", k)
})
}
entries.RangeAll(func(_ string, e *M.RawEntry) {
e.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
return routes, errors.Build()
}
func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
}
})
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
b.Add(E.FailWith("inspect container", err))
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *M.RawEntry) {
if routes.Has(alias) {
b.Add(E.Duplicated("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
}
})
return
}
// Returns a list of proxy entries for a container.
// Always non-nil
func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.RawEntries, E.NestedError) {
entries := M.NewProxyEntries()
if container.IsExcluded {
return entries, nil
}
// init entries map for all aliases
for _, a := range container.Aliases {
entries.Store(a, &M.RawEntry{
Alias: a,
Host: p.hostname,
ProxyProperties: container.ProxyProperties,
})
}
errors := E.NewBuilder("failed to apply label")
for key, val := range container.Labels {
errors.Add(p.applyLabel(container, entries, key, val))
}
// remove all entries that failed to fill in missing fields
entries.RemoveAll(func(re *M.RawEntry) bool {
return !re.FillMissingFields()
})
return entries, errors.Build().Subject(container.ContainerName)
}
func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
refErr := E.NewBuilder("errors parsing alias references")
replaceIndexRef := func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
refErr.Add(E.Invalid("integer", ref))
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.OutOfRange("index", ref))
return ref
}
return container.Aliases[index-1]
}
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
}
if lbl.Namespace != D.NSProxy {
return
}
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *M.RawEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
}
})
} else {
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef)
lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string {
logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#"))
return replaceIndexRef(s)
})
if refErr.HasError() {
b.Add(refErr.Build())
return
}
config, ok := entries.Load(lbl.Target)
if !ok {
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subjectf("alias %s", lbl.Target))
}
}
return
}

View File

@@ -0,0 +1,307 @@
package provider
import (
"strings"
"testing"
"github.com/docker/docker/api/types"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
T "github.com/yusing/go-proxy/internal/proxy/fields"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var dummyNames = []string{"/a"}
func TestApplyLabelFieldValidity(t *testing.T) {
pathPatterns := `
- /
- POST /upload/{$}
- GET /static
`[1:]
pathPatternsExpect := []string{
"/",
"POST /upload/{$}",
"GET /static",
}
middlewaresExpect := D.NestedLabelMap{
"middleware1": {
"prop1": "value1",
"prop2": "value2",
},
"middleware2": {
"prop3": "value3",
"prop4": "value4",
},
}
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault,
D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM",
D.LabelStopTimeout: common.StopTimeoutDefault,
D.LabelWakeTimeout: common.WakeTimeoutDefault,
"proxy.*.no_tls_verify": "true",
"proxy.*.scheme": "https",
"proxy.*.host": "app",
"proxy.*.port": "4567",
"proxy.a.no_tls_verify": "true",
"proxy.a.path_patterns": pathPatterns,
"proxy.a.middlewares.middleware1.prop1": "value1",
"proxy.a.middlewares.middleware1.prop2": "value2",
"proxy.a.middlewares.middleware2.prop3": "value3",
"proxy.a.middlewares.middleware2.prop4": "value4",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 4567, PublicPort: 8888},
}}, ""))
ExpectNoError(t, err.Error())
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
ExpectEqual(t, a.Scheme, "https")
ExpectEqual(t, b.Scheme, "https")
ExpectEqual(t, a.Host, "app")
ExpectEqual(t, b.Host, "app")
ExpectEqual(t, a.Port, "8888")
ExpectEqual(t, b.Port, "8888")
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
ExpectDeepEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, len(b.PathPatterns), 0)
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.StopSignal, "SIGTERM")
ExpectEqual(t, b.StopSignal, "SIGTERM")
}
func TestApplyLabel(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.a.no_tls_verify": "true",
"proxy.a.port": "3333",
"proxy.b.port": "1234",
"proxy.c.scheme": "https",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 1111},
{Type: "tcp", PrivatePort: 4444, PublicPort: 1234},
}}, "",
))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Port, "1111")
ExpectEqual(t, a.NoTLSVerify, true)
ExpectEqual(t, b.Scheme, "http")
ExpectEqual(t, b.Port, "1234")
ExpectEqual(t, c.Scheme, "https")
// map does not necessary follow the order above
ExpectEqualAny(t, c.Port, []string{"1111", "1234"})
}
func TestApplyLabelWithRef(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.#1.host": "localhost",
"proxy.#1.port": "4444",
"proxy.#2.port": "9999",
"proxy.#3.port": "1111",
"proxy.#3.scheme": "https",
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 9999},
{Type: "tcp", PrivatePort: 4444, PublicPort: 5555},
{Type: "tcp", PrivatePort: 1111, PublicPort: 2222},
}}, ""))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Host, "localhost")
ExpectEqual(t, a.Port, "5555")
ExpectEqual(t, b.Port, "9999")
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port, "2222")
}
func TestApplyLabelWithRefIndexError(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#1.host": "localhost",
"proxy.#4.scheme": "https",
}}, "")
_, err := p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#0.host": "localhost",
}}, ""))
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
}
func TestStreamDefaultValues(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.*.no_tls_verify": "true",
},
Ports: []types.Port{
{Type: "udp", PrivatePort: 1234, PublicPort: 5678},
}}, "",
)
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
raw, ok := entries.Load("a")
ExpectTrue(t, ok)
entry, err := P.ValidateEntry(raw)
ExpectNoError(t, err.Error())
a := ExpectType[*P.StreamEntry](t, entry)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Port.ListeningPort, 0)
ExpectEqual(t, a.Port.ProxyPort, 5678)
}
func TestExplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
D.LabelExclude: "true",
"proxy.a.no_tls_verify": "true",
}}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.a.no_tls_verify": "true",
},
State: "running",
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExcludeNoExposedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
State: "running",
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectFalse(t, ok)
}
func TestNotExcludeSpecifiedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
Labels: map[string]string{
"proxy.redis.port": "6379:6379", // but specified in label
},
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectTrue(t, ok)
}
func TestNotExcludeNonExposedPortHostNetwork(t *testing.T) {
var p DockerProvider
cont := &types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
Labels: map[string]string{
"proxy.redis.port": "6379:6379",
},
}
cont.HostConfig.NetworkMode = "host"
entries, err := p.entriesFromContainerLabels(D.FromDocker(cont, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectTrue(t, ok)
}

View File

@@ -0,0 +1,98 @@
package provider
import (
"errors"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
M "github.com/yusing/go-proxy/internal/models"
R "github.com/yusing/go-proxy/internal/route"
U "github.com/yusing/go-proxy/internal/utils"
W "github.com/yusing/go-proxy/internal/watcher"
)
type FileProvider struct {
fileName string
path string
}
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
}
_, err := os.Stat(impl.path)
switch {
case err == nil:
return impl, nil
case errors.Is(err, os.ErrNotExist):
return nil, E.NotExist("file", impl.path)
default:
return nil, E.UnexpectedError(err)
}
}
func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
}
func (p FileProvider) String() string {
return p.fileName
}
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
b.Add(err)
return
}
routes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Start())
})
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
routes = R.NewRoutes()
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := M.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err.HasError() {
b.Add(E.FailWith("read file", err))
return
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
b.Add(err)
return
}
}
if err = entries.UnmarshalFromYAML(data); err.HasError() {
b.Add(err)
return
}
return R.FromEntries(entries)
}
func (p *FileProvider) NewWatcher() W.Watcher {
return W.NewConfigFileWatcher(p.fileName)
}

View File

@@ -0,0 +1,178 @@
package provider
import (
"context"
"path"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher"
)
type (
Provider struct {
ProviderImpl `json:"-"`
name string
t ProviderType
routes R.Routes
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
l *logrus.Entry
}
ProviderImpl interface {
NewWatcher() W.Watcher
// even returns error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult
String() string
}
ProviderType string
EventResult struct {
nRemoved int
nAdded int
err E.NestedError
}
)
const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
)
func newProvider(name string, t ProviderType) *Provider {
p := &Provider{
name: name,
t: t,
routes: R.NewRoutes(),
}
p.l = logrus.WithField("provider", p)
return p
}
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
name := path.Base(filename)
p = newProvider(name, ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return
}
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
p = newProvider(name, ProviderTypeDocker)
p.ProviderImpl, err = DockerProviderImpl(dockerHost)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return
}
func (p *Provider) GetName() string {
return p.name
}
func (p *Provider) GetType() ProviderType {
return p.t
}
// to work with json marshaller
func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) StartAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors in routes")
defer errors.To(&res)
// start watcher no matter load success or not
go p.watchEvents()
nStarted := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStarted++
}
})
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return
}
func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil {
p.watcherCancel()
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
defer errors.To(&res)
nStopped := 0
nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
nStopped++
}
})
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return
}
func (p *Provider) RangeRoutes(do func(string, R.Route)) {
p.routes.RangeAll(do)
}
func (p *Provider) GetRoute(alias string) (R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError
p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 {
p.l.Infof("loaded %d routes", p.routes.Size())
return err
}
return E.FailWith("loading routes", err)
}
func (p *Provider) watchEvents() {
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
events, errs := p.watcher.Events(p.watcherCtx)
l := p.l.WithField("module", "watcher")
for {
select {
case <-p.watcherCtx.Done():
return
case event := <-events:
res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() {
l.Error(res.err)
}
case err := <-errs:
if err == nil || err.Is(context.Canceled) {
continue
}
l.Errorf("watcher error: %s", err)
}
}
}