mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 17:28:31 +02:00
fixed loadbalancer with idlewatcher, fixed reload issue
This commit is contained in:
@@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
|
||||
}
|
||||
}
|
||||
|
||||
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
|
||||
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
|
||||
var trans *http.Transport
|
||||
|
||||
if entry.NoTLSVerify {
|
||||
@@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string {
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
@@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
r.addToLoadBalancer()
|
||||
} else {
|
||||
httpRoutes.Store(string(r.Alias), r)
|
||||
r.task.OnComplete("stop rp", func() {
|
||||
r.task.OnFinished("remove from route table", func() {
|
||||
httpRoutes.Delete(string(r.Alias))
|
||||
})
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (r *HTTPRoute) Finish(reason string) {
|
||||
func (r *HTTPRoute) Finish(reason any) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
@@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
||||
}
|
||||
} else {
|
||||
lb = loadbalancer.New(r.LoadBalance)
|
||||
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
|
||||
lbTask.OnComplete("remove lb from routes", func() {
|
||||
lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link)
|
||||
lbTask.OnCancel("remove lb from routes", func() {
|
||||
httpRoutes.Delete(r.LoadBalance.Link)
|
||||
})
|
||||
lb.Start(lbTask)
|
||||
@@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
||||
httpRoutes.Store(r.LoadBalance.Link, linked)
|
||||
}
|
||||
r.loadBalancer = lb
|
||||
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
|
||||
r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
|
||||
lb.AddServer(r.server)
|
||||
r.task.OnComplete("remove server from lb", func() {
|
||||
r.task.OnCancel("remove server from lb", func() {
|
||||
lb.RemoveServer(r.server)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ var (
|
||||
AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
|
||||
)
|
||||
|
||||
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) {
|
||||
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) {
|
||||
if dockerHost == common.DockerHostFromEnv {
|
||||
dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost)
|
||||
}
|
||||
@@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
|
||||
return W.NewDockerWatcher(p.dockerHost)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
||||
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) {
|
||||
routes = R.NewRoutes()
|
||||
entries := entry.NewProxyEntries()
|
||||
|
||||
info, err := D.GetClientInfo(p.dockerHost, true)
|
||||
containers, err := D.ListContainers(p.dockerHost)
|
||||
if err != nil {
|
||||
return routes, E.FailWith("connect to docker", err)
|
||||
return routes, err
|
||||
}
|
||||
|
||||
errors := E.NewBuilder("errors in docker labels")
|
||||
|
||||
for _, c := range info.Containers {
|
||||
for _, c := range containers {
|
||||
container := D.FromDocker(&c, p.dockerHost)
|
||||
if container.IsExcluded {
|
||||
continue
|
||||
@@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
||||
})
|
||||
}
|
||||
|
||||
entries.RangeAll(func(_ string, e *entry.RawEntry) {
|
||||
e.Container.DockerHost = p.dockerHost
|
||||
})
|
||||
|
||||
routes, err = R.FromEntries(entries)
|
||||
errors.Add(err)
|
||||
|
||||
@@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
|
||||
|
||||
// Returns a list of proxy entries for a container.
|
||||
// Always non-nil.
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) {
|
||||
entries = entry.NewProxyEntries()
|
||||
|
||||
if p.shouldIgnore(container) {
|
||||
@@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
|
||||
return entries, errors.Build().Subject(container.ContainerName)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) {
|
||||
b := E.NewBuilder("errors in label %s", key)
|
||||
defer b.To(&res)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
@@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
oldRoutes.RangeAll(func(k string, v *route.Route) {
|
||||
if !newRoutes.Has(k) {
|
||||
handler.Remove(v)
|
||||
if common.IsDebug {
|
||||
eventsLog := E.NewBuilder("events")
|
||||
for _, event := range events {
|
||||
eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID)
|
||||
}
|
||||
handler.provider.l.Debug(eventsLog.String())
|
||||
oldRoutesLog := E.NewBuilder("old routes")
|
||||
oldRoutes.RangeAll(func(k string, r *route.Route) {
|
||||
oldRoutesLog.Addf(k)
|
||||
})
|
||||
handler.provider.l.Debug(oldRoutesLog.String())
|
||||
newRoutesLog := E.NewBuilder("new routes")
|
||||
newRoutes.RangeAll(func(k string, r *route.Route) {
|
||||
newRoutesLog.Addf(k)
|
||||
})
|
||||
handler.provider.l.Debug(newRoutesLog.String())
|
||||
}
|
||||
|
||||
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
|
||||
newr, ok := newRoutes.Load(k)
|
||||
if !ok {
|
||||
handler.Remove(oldr)
|
||||
} else if handler.matchAny(events, newr) {
|
||||
handler.Update(parent, oldr, newr)
|
||||
} else if entry.ShouldNotServe(newr) {
|
||||
handler.Remove(oldr)
|
||||
}
|
||||
})
|
||||
newRoutes.RangeAll(func(k string, newr *route.Route) {
|
||||
if oldRoutes.Has(k) {
|
||||
for _, ev := range events {
|
||||
if handler.match(ev, newr) {
|
||||
old, ok := oldRoutes.Load(k)
|
||||
if !ok { // should not happen
|
||||
panic("race condition")
|
||||
}
|
||||
handler.Update(parent, old, newr)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) {
|
||||
handler.Add(parent, newr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool {
|
||||
for _, event := range events {
|
||||
if handler.match(event, route) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
|
||||
switch handler.provider.t {
|
||||
switch handler.provider.GetType() {
|
||||
case ProviderTypeDocker:
|
||||
return route.Entry.Container.ContainerID == event.ActorID ||
|
||||
route.Entry.Container.ContainerName == event.ActorName
|
||||
@@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
|
||||
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
|
||||
err := handler.provider.startRoute(parent, route)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err))
|
||||
} else {
|
||||
handler.added = append(handler.added, route.Entry.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Remove(route *route.Route) {
|
||||
route.Finish("route removal")
|
||||
route.Finish("route removed")
|
||||
handler.provider.routes.Delete(route.Entry.Alias)
|
||||
handler.removed = append(handler.removed, route.Entry.Alias)
|
||||
}
|
||||
|
||||
@@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
|
||||
oldRoute.Finish("route update")
|
||||
err := handler.provider.startRoute(parent, newRoute)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err))
|
||||
} else {
|
||||
handler.updated = append(handler.updated, newRoute.Entry.Alias)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ type FileProvider struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
|
||||
func FileProviderImpl(filename string) (ProviderImpl, E.Error) {
|
||||
impl := &FileProvider{
|
||||
fileName: filename,
|
||||
path: path.Join(common.ConfigBasePath, filename),
|
||||
@@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
|
||||
}
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
func Validate(data []byte) E.Error {
|
||||
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (p FileProvider) String() string {
|
||||
return p.fileName
|
||||
}
|
||||
|
||||
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
|
||||
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) {
|
||||
routes = R.NewRoutes()
|
||||
|
||||
b := E.NewBuilder("validation failure")
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
@@ -29,7 +28,7 @@ type (
|
||||
ProviderImpl interface {
|
||||
fmt.Stringer
|
||||
NewWatcher() W.Watcher
|
||||
LoadRoutesImpl() (R.Routes, E.NestedError)
|
||||
LoadRoutesImpl() (R.Routes, E.Error)
|
||||
}
|
||||
ProviderType string
|
||||
ProviderStats struct {
|
||||
@@ -43,7 +42,7 @@ const (
|
||||
ProviderTypeDocker ProviderType = "docker"
|
||||
ProviderTypeFile ProviderType = "file"
|
||||
|
||||
providerEventFlushInterval = 500 * time.Millisecond
|
||||
providerEventFlushInterval = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProvider(name string, t ProviderType) *Provider {
|
||||
@@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider {
|
||||
return p
|
||||
}
|
||||
|
||||
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
|
||||
func NewFileProvider(filename string) (p *Provider, err E.Error) {
|
||||
name := path.Base(filename)
|
||||
if name == "" {
|
||||
return nil, E.Invalid("file name", "empty")
|
||||
@@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
|
||||
return
|
||||
}
|
||||
|
||||
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
|
||||
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) {
|
||||
if name == "" {
|
||||
return nil, E.Invalid("provider name", "empty")
|
||||
}
|
||||
@@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
|
||||
if entry.UseLoadBalance(r) {
|
||||
r.Entry.Alias = p.String() + "/" + r.Entry.Alias
|
||||
}
|
||||
subtask := parent.Subtask(r.Entry.Alias)
|
||||
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
|
||||
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
|
||||
err := r.Start(subtask)
|
||||
if err != nil {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
subtask.Finish(err.String()) // just to ensure
|
||||
subtask.Finish(err) // just to ensure
|
||||
return err
|
||||
} else {
|
||||
subtask.OnComplete("del from provider", func() {
|
||||
p.routes.Store(r.Entry.Alias, r)
|
||||
subtask.OnFinished("del from provider", func() {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
})
|
||||
}
|
||||
@@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
|
||||
func (p *Provider) Start(configSubtask task.Task) (res E.Error) {
|
||||
errors := E.NewBuilder("errors starting routes")
|
||||
defer errors.To(&res)
|
||||
|
||||
@@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
|
||||
handler.Log()
|
||||
flushTask.Finish("events flushed")
|
||||
},
|
||||
func(err E.NestedError) {
|
||||
func(err E.Error) {
|
||||
p.l.Error(err)
|
||||
},
|
||||
)
|
||||
@@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
|
||||
return p.routes.Load(alias)
|
||||
}
|
||||
|
||||
func (p *Provider) LoadRoutes() E.NestedError {
|
||||
var err E.NestedError
|
||||
func (p *Provider) LoadRoutes() E.Error {
|
||||
var err E.Error
|
||||
p.routes, err = p.LoadRoutesImpl()
|
||||
if p.routes.Size() > 0 {
|
||||
return err
|
||||
|
||||
94
internal/route/raw.go
Normal file
94
internal/route/raw.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
RawStream struct {
|
||||
*StreamRoute
|
||||
|
||||
listener net.Listener
|
||||
targetAddr net.Addr
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
streamBufferSize = 8192
|
||||
streamDialTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func NewRawStreamRoute(base *StreamRoute) *RawStream {
|
||||
return &RawStream{
|
||||
StreamRoute: base,
|
||||
}
|
||||
}
|
||||
|
||||
func (route *RawStream) Setup() error {
|
||||
var lcfg net.ListenConfig
|
||||
var err error
|
||||
|
||||
switch route.Scheme.ListeningScheme {
|
||||
case "tcp":
|
||||
route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
|
||||
route.listener = tcpListener
|
||||
case "udp":
|
||||
route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port)
|
||||
route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener)
|
||||
default:
|
||||
return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (route *RawStream) Accept() (net.Conn, error) {
|
||||
if route.listener == nil {
|
||||
return nil, errors.New("listener not yet set up")
|
||||
}
|
||||
return route.listener.Accept()
|
||||
}
|
||||
|
||||
func (route *RawStream) Handle(c net.Conn) error {
|
||||
clientConn := c.(net.Conn)
|
||||
|
||||
defer clientConn.Close()
|
||||
route.task.OnCancel("close conn", func() { clientConn.Close() })
|
||||
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
|
||||
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
return pipe.Start()
|
||||
}
|
||||
|
||||
func (route *RawStream) Close() error {
|
||||
return route.listener.Close()
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container {
|
||||
return rt.Entry.Container
|
||||
}
|
||||
|
||||
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
|
||||
func NewRoute(raw *entry.RawEntry) (*Route, E.Error) {
|
||||
en, err := entry.ValidateEntry(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
|
||||
func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
|
||||
b := E.NewBuilder("errors in routes")
|
||||
|
||||
routes := NewRoutes()
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdNet "net"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] {
|
||||
return streamRoutes
|
||||
}
|
||||
|
||||
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
|
||||
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
|
||||
// TODO: support non-coherent scheme
|
||||
if !entry.Scheme.IsCoherent() {
|
||||
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
|
||||
@@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Finish(reason string) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) String() string {
|
||||
return fmt.Sprintf("stream %s", r.Alias)
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
func (r *StreamRoute) Start(providerSubtask task.Task) E.Error {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
@@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
r.HealthCheck.Disable = true
|
||||
}
|
||||
|
||||
if r.Scheme.ListeningScheme.IsTCP() {
|
||||
r.Stream = NewTCPRoute(r)
|
||||
} else {
|
||||
r.Stream = NewUDPRoute(r)
|
||||
}
|
||||
// if r.Scheme.ListeningScheme.IsTCP() {
|
||||
// r.Stream = NewTCPRoute(r)
|
||||
// } else {
|
||||
// r.Stream = NewUDPRoute(r)
|
||||
// }
|
||||
r.task = providerSubtask
|
||||
r.Stream = NewRawStreamRoute(r)
|
||||
r.l = logrus.WithField("route", r.Stream.String())
|
||||
|
||||
switch {
|
||||
@@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
|
||||
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
|
||||
if err != nil {
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
r.Stream = waker
|
||||
@@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
case entry.UseHealthCheck(r):
|
||||
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
|
||||
}
|
||||
r.task = providerSubtask
|
||||
r.task.OnComplete("stop stream", r.CloseListeners)
|
||||
|
||||
if err := r.Setup(); err != nil {
|
||||
r.task.Finish(err)
|
||||
return E.FailWith("setup", err)
|
||||
}
|
||||
r.l.Infof("listening on port %d", r.Port.ListeningPort)
|
||||
|
||||
go r.acceptConnections()
|
||||
r.task.OnFinished("close stream", func() {
|
||||
if err := r.Close(); err != nil {
|
||||
r.l.Error("close stream error: ", err)
|
||||
}
|
||||
})
|
||||
r.task.OnFinished("remove from route table", func() {
|
||||
streamRoutes.Delete(string(r.Alias))
|
||||
})
|
||||
|
||||
r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort)
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Start(r.task.Subtask("health monitor"))
|
||||
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
|
||||
logrus.Warn("health monitor error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
go r.acceptConnections()
|
||||
streamRoutes.Store(string(r.Alias), r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Finish(reason any) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) acceptConnections() {
|
||||
defer r.task.Finish("listener closed")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
@@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() {
|
||||
if err != nil {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
return
|
||||
default:
|
||||
var nErr *stdNet.OpError
|
||||
ok := errors.As(err, &nErr)
|
||||
if !(ok && nErr.Timeout()) {
|
||||
r.l.Error("accept connection error: ", err)
|
||||
r.task.Finish(err.Error())
|
||||
return
|
||||
}
|
||||
continue
|
||||
r.l.Error("accept connection error: ", err)
|
||||
r.task.Finish(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr()))
|
||||
go func() {
|
||||
err := r.Handle(conn)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
r.l.Error(err)
|
||||
connTask.Finish(err.Error())
|
||||
} else {
|
||||
connTask.Finish("connection closed")
|
||||
}
|
||||
|
||||
@@ -1,71 +1,68 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "net"
|
||||
// "time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
// "github.com/yusing/go-proxy/internal/net/types"
|
||||
// T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
// U "github.com/yusing/go-proxy/internal/utils"
|
||||
// F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
// )
|
||||
|
||||
const tcpDialTimeout = 5 * time.Second
|
||||
// const tcpDialTimeout = 5 * time.Second
|
||||
|
||||
type (
|
||||
TCPConnMap = F.Map[net.Conn, struct{}]
|
||||
TCPRoute struct {
|
||||
*StreamRoute
|
||||
listener *net.TCPListener
|
||||
}
|
||||
)
|
||||
// type (
|
||||
// TCPConnMap = F.Map[net.Conn, struct{}]
|
||||
// TCPRoute struct {
|
||||
// *StreamRoute
|
||||
// listener *net.TCPListener
|
||||
// }
|
||||
// )
|
||||
|
||||
func NewTCPRoute(base *StreamRoute) *TCPRoute {
|
||||
return &TCPRoute{StreamRoute: base}
|
||||
}
|
||||
// func NewTCPRoute(base *StreamRoute) *TCPRoute {
|
||||
// return &TCPRoute{StreamRoute: base}
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Setup() error {
|
||||
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//! this read the allocated port from original ':0'
|
||||
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
|
||||
route.listener = in.(*net.TCPListener)
|
||||
return nil
|
||||
}
|
||||
// func (route *TCPRoute) Setup() error {
|
||||
// var cfg net.ListenConfig
|
||||
// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// //! this read the allocated port from original ':0'
|
||||
// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
|
||||
// route.listener = in.(*net.TCPListener)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Accept() (types.StreamConn, error) {
|
||||
route.listener.SetDeadline(time.Now().Add(time.Second))
|
||||
return route.listener.Accept()
|
||||
}
|
||||
// func (route *TCPRoute) Accept() (types.StreamConn, error) {
|
||||
// return route.listener.Accept()
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Handle(c types.StreamConn) error {
|
||||
clientConn := c.(net.Conn)
|
||||
// func (route *TCPRoute) Handle(c types.StreamConn) error {
|
||||
// clientConn := c.(net.Conn)
|
||||
|
||||
defer clientConn.Close()
|
||||
route.task.OnComplete("close conn", func() { clientConn.Close() })
|
||||
// defer clientConn.Close()
|
||||
// route.task.OnCancel("close conn", func() { clientConn.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
|
||||
// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
|
||||
|
||||
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
dialer := &net.Dialer{}
|
||||
// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
// dialer := &net.Dialer{}
|
||||
|
||||
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
|
||||
// cancel()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
return pipe.Start()
|
||||
}
|
||||
// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
// return pipe.Start()
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) CloseListeners() {
|
||||
if route.listener == nil {
|
||||
return
|
||||
}
|
||||
route.listener.Close()
|
||||
}
|
||||
// func (route *TCPRoute) Close() error {
|
||||
// return route.listener.Close()
|
||||
// }
|
||||
|
||||
@@ -1,145 +1,149 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
// import (
|
||||
// "errors"
|
||||
// "fmt"
|
||||
// "io"
|
||||
// "net"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
// "github.com/yusing/go-proxy/internal/net/types"
|
||||
// T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
// U "github.com/yusing/go-proxy/internal/utils"
|
||||
// F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
// )
|
||||
|
||||
type (
|
||||
UDPRoute struct {
|
||||
*StreamRoute
|
||||
// type (
|
||||
// UDPRoute struct {
|
||||
// *StreamRoute
|
||||
|
||||
connMap UDPConnMap
|
||||
// connMap UDPConnMap
|
||||
|
||||
listeningConn *net.UDPConn
|
||||
targetAddr *net.UDPAddr
|
||||
}
|
||||
UDPConn struct {
|
||||
key string
|
||||
src *net.UDPConn
|
||||
dst *net.UDPConn
|
||||
U.BidirectionalPipe
|
||||
}
|
||||
UDPConnMap = F.Map[string, *UDPConn]
|
||||
)
|
||||
// listeningConn net.PacketConn
|
||||
// targetAddr *net.UDPAddr
|
||||
// }
|
||||
// UDPConn struct {
|
||||
// key string
|
||||
// src net.Conn
|
||||
// dst net.Conn
|
||||
// U.BidirectionalPipe
|
||||
// }
|
||||
// UDPConnMap = F.Map[string, *UDPConn]
|
||||
// )
|
||||
|
||||
var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
// var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
|
||||
const udpBufferSize = 8192
|
||||
// const udpBufferSize = 8192
|
||||
|
||||
func NewUDPRoute(base *StreamRoute) *UDPRoute {
|
||||
return &UDPRoute{
|
||||
StreamRoute: base,
|
||||
connMap: NewUDPConnMap(),
|
||||
}
|
||||
}
|
||||
// func NewUDPRoute(base *StreamRoute) *UDPRoute {
|
||||
// return &UDPRoute{
|
||||
// StreamRoute: base,
|
||||
// connMap: NewUDPConnMap(),
|
||||
// }
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Setup() error {
|
||||
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
source.Close()
|
||||
return err
|
||||
}
|
||||
// func (route *UDPRoute) Setup() error {
|
||||
// var cfg net.ListenConfig
|
||||
// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
// if err != nil {
|
||||
// source.Close()
|
||||
// return err
|
||||
// }
|
||||
|
||||
//! this read the allocated listeningPort from original ':0'
|
||||
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
|
||||
// //! this read the allocated listeningPort from original ':0'
|
||||
// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
route.listeningConn = source
|
||||
route.targetAddr = raddr
|
||||
// route.listeningConn = source
|
||||
// route.targetAddr = raddr
|
||||
|
||||
return nil
|
||||
}
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Accept() (types.StreamConn, error) {
|
||||
in := route.listeningConn
|
||||
// func (route *UDPRoute) Accept() (types.StreamConn, error) {
|
||||
// in := route.listeningConn
|
||||
|
||||
buffer := make([]byte, udpBufferSize)
|
||||
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
nRead, srcAddr, err := in.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// buffer := make([]byte, udpBufferSize)
|
||||
// nRead, srcAddr, err := in.ReadFrom(buffer)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
if nRead == 0 {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
// if nRead == 0 {
|
||||
// return nil, io.ErrShortBuffer
|
||||
// }
|
||||
|
||||
key := srcAddr.String()
|
||||
conn, ok := route.connMap.Load(key)
|
||||
// key := srcAddr.String()
|
||||
// conn, ok := route.connMap.Load(key)
|
||||
|
||||
if !ok {
|
||||
srcConn, err := net.DialUDP("udp", nil, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
|
||||
if err != nil {
|
||||
srcConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn = &UDPConn{
|
||||
key,
|
||||
srcConn,
|
||||
dstConn,
|
||||
U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
|
||||
}
|
||||
route.connMap.Store(key, conn)
|
||||
}
|
||||
// if !ok {
|
||||
// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String())
|
||||
// if err != nil {
|
||||
// srcConn.Close()
|
||||
// return nil, err
|
||||
// }
|
||||
// conn = &UDPConn{
|
||||
// key,
|
||||
// srcConn,
|
||||
// dstConn,
|
||||
// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
|
||||
// }
|
||||
// route.connMap.Store(key, conn)
|
||||
// }
|
||||
|
||||
_, err = conn.dst.Write(buffer[:nRead])
|
||||
return conn, err
|
||||
}
|
||||
// _, err = conn.dst.Write(buffer[:nRead])
|
||||
// return conn, err
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Handle(c types.StreamConn) error {
|
||||
conn := c.(*UDPConn)
|
||||
err := conn.Start()
|
||||
route.connMap.Delete(conn.key)
|
||||
return err
|
||||
}
|
||||
// func (route *UDPRoute) Handle(c types.StreamConn) error {
|
||||
// switch c := c.(type) {
|
||||
// case *UDPConn:
|
||||
// err := c.Start()
|
||||
// route.connMap.Delete(c.key)
|
||||
// c.Close()
|
||||
// return err
|
||||
// case *net.TCPConn:
|
||||
// in := route.listeningConn
|
||||
// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start()
|
||||
// c.Close()
|
||||
// return err
|
||||
// }
|
||||
// return fmt.Errorf("unknown conn type: %T", c)
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) CloseListeners() {
|
||||
if route.listeningConn != nil {
|
||||
route.listeningConn.Close()
|
||||
}
|
||||
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
|
||||
if err := conn.Close(); err != nil {
|
||||
route.l.Errorf("error closing conn: %s", err)
|
||||
}
|
||||
})
|
||||
route.connMap.Clear()
|
||||
}
|
||||
// func (route *UDPRoute) Close() error {
|
||||
// route.connMap.RangeAllParallel(func(k string, v *UDPConn) {
|
||||
// v.Close()
|
||||
// })
|
||||
// route.connMap.Clear()
|
||||
// return route.listeningConn.Close()
|
||||
// }
|
||||
|
||||
// Close implements types.StreamConn
|
||||
func (conn *UDPConn) Close() error {
|
||||
return errors.Join(conn.src.Close(), conn.dst.Close())
|
||||
}
|
||||
// // Close implements types.StreamConn
|
||||
// func (conn *UDPConn) Close() error {
|
||||
// return errors.Join(conn.src.Close(), conn.dst.Close())
|
||||
// }
|
||||
|
||||
// RemoteAddr implements types.StreamConn
|
||||
func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
return conn.src.RemoteAddr()
|
||||
}
|
||||
// // RemoteAddr implements types.StreamConn
|
||||
// func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
// return conn.src.RemoteAddr()
|
||||
// }
|
||||
|
||||
type sourceRWCloser struct {
|
||||
server *net.UDPConn
|
||||
*net.UDPConn
|
||||
}
|
||||
// type sourceRWCloser struct {
|
||||
// server net.PacketConn
|
||||
// net.Conn
|
||||
// }
|
||||
|
||||
func (w sourceRWCloser) Write(p []byte) (int, error) {
|
||||
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
|
||||
}
|
||||
// func (w sourceRWCloser) Write(p []byte) (int, error) {
|
||||
// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr))
|
||||
// }
|
||||
|
||||
73
internal/route/udp_listener.go
Normal file
73
internal/route/udp_listener.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
UDPListener struct {
|
||||
ctx context.Context
|
||||
listener net.PacketConn
|
||||
connMap UDPConnMap
|
||||
mu sync.Mutex
|
||||
}
|
||||
UDPConnMap = F.Map[string, net.Conn]
|
||||
)
|
||||
|
||||
var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
|
||||
func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener {
|
||||
return &UDPListener{
|
||||
ctx: ctx,
|
||||
listener: listener,
|
||||
connMap: NewUDPConnMap(),
|
||||
}
|
||||
}
|
||||
|
||||
// Addr implements net.Listener.
|
||||
func (route *UDPListener) Addr() net.Addr {
|
||||
return route.listener.LocalAddr()
|
||||
}
|
||||
|
||||
func (udpl *UDPListener) Accept() (net.Conn, error) {
|
||||
in := udpl.listener
|
||||
|
||||
buffer := make([]byte, streamBufferSize)
|
||||
nRead, srcAddr, err := in.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if nRead == 0 {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
udpl.mu.Lock()
|
||||
defer udpl.mu.Unlock()
|
||||
|
||||
key := srcAddr.String()
|
||||
conn, ok := udpl.connMap.Load(key)
|
||||
if !ok {
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpl.connMap.Store(key, srcConn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close implements net.Listener.
|
||||
func (route *UDPListener) Close() error {
|
||||
route.connMap.RangeAllParallel(func(key string, conn net.Conn) {
|
||||
conn.Close()
|
||||
})
|
||||
route.connMap.Clear()
|
||||
return route.listener.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user