fixed loadbalancer with idlewatcher, fixed reload issue

This commit is contained in:
yusing
2024-10-20 09:46:02 +08:00
parent 01ffe0d97c
commit a278711421
78 changed files with 906 additions and 609 deletions

View File

@@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
}
}
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
var trans *http.Transport
if entry.NoTLSVerify {
@@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string {
}
// Start implements task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
@@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
r.addToLoadBalancer()
} else {
httpRoutes.Store(string(r.Alias), r)
r.task.OnComplete("stop rp", func() {
r.task.OnFinished("remove from route table", func() {
httpRoutes.Delete(string(r.Alias))
})
}
@@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
}
// Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason string) {
func (r *HTTPRoute) Finish(reason any) {
r.task.Finish(reason)
}
@@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() {
}
} else {
lb = loadbalancer.New(r.LoadBalance)
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
lbTask.OnComplete("remove lb from routes", func() {
lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link)
lbTask.OnCancel("remove lb from routes", func() {
httpRoutes.Delete(r.LoadBalance.Link)
})
lb.Start(lbTask)
@@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
httpRoutes.Store(r.LoadBalance.Link, linked)
}
r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server)
r.task.OnComplete("remove server from lb", func() {
r.task.OnCancel("remove server from lb", func() {
lb.RemoveServer(r.server)
})
}

View File

@@ -25,7 +25,7 @@ var (
AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
)
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) {
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) {
if dockerHost == common.DockerHostFromEnv {
dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost)
}
@@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) {
routes = R.NewRoutes()
entries := entry.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
containers, err := D.ListContainers(p.dockerHost)
if err != nil {
return routes, E.FailWith("connect to docker", err)
return routes, err
}
errors := E.NewBuilder("errors in docker labels")
for _, c := range info.Containers {
for _, c := range containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
@@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
})
}
entries.RangeAll(func(_ string, e *entry.RawEntry) {
e.Container.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
@@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
// Returns a list of proxy entries for a container.
// Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) {
entries = entry.NewProxyEntries()
if p.shouldIgnore(container) {
@@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
return entries, errors.Build().Subject(container.ContainerName)
}
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)

View File

@@ -1,7 +1,9 @@
package provider
import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher"
@@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
return
}
oldRoutes.RangeAll(func(k string, v *route.Route) {
if !newRoutes.Has(k) {
handler.Remove(v)
if common.IsDebug {
eventsLog := E.NewBuilder("events")
for _, event := range events {
eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID)
}
handler.provider.l.Debug(eventsLog.String())
oldRoutesLog := E.NewBuilder("old routes")
oldRoutes.RangeAll(func(k string, r *route.Route) {
oldRoutesLog.Addf(k)
})
handler.provider.l.Debug(oldRoutesLog.String())
newRoutesLog := E.NewBuilder("new routes")
newRoutes.RangeAll(func(k string, r *route.Route) {
newRoutesLog.Addf(k)
})
handler.provider.l.Debug(newRoutesLog.String())
}
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
newr, ok := newRoutes.Load(k)
if !ok {
handler.Remove(oldr)
} else if handler.matchAny(events, newr) {
handler.Update(parent, oldr, newr)
} else if entry.ShouldNotServe(newr) {
handler.Remove(oldr)
}
})
newRoutes.RangeAll(func(k string, newr *route.Route) {
if oldRoutes.Has(k) {
for _, ev := range events {
if handler.match(ev, newr) {
old, ok := oldRoutes.Load(k)
if !ok { // should not happen
panic("race condition")
}
handler.Update(parent, old, newr)
return
}
}
} else {
if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) {
handler.Add(parent, newr)
}
})
}
func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool {
for _, event := range events {
if handler.match(event, route) {
return true
}
}
return false
}
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.t {
switch handler.provider.GetType() {
case ProviderTypeDocker:
return route.Entry.Container.ContainerID == event.ActorID ||
route.Entry.Container.ContainerName == event.ActorName
@@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
err := handler.provider.startRoute(parent, route)
if err != nil {
handler.errs.Add(err)
handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err))
} else {
handler.added = append(handler.added, route.Entry.Alias)
}
}
func (handler *EventHandler) Remove(route *route.Route) {
route.Finish("route removal")
route.Finish("route removed")
handler.provider.routes.Delete(route.Entry.Alias)
handler.removed = append(handler.removed, route.Entry.Alias)
}
@@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute)
if err != nil {
handler.errs.Add(err)
handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err))
} else {
handler.updated = append(handler.updated, newRoute.Entry.Alias)
}

View File

@@ -18,7 +18,7 @@ type FileProvider struct {
path string
}
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
func FileProviderImpl(filename string) (ProviderImpl, E.Error) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
@@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
}
}
func Validate(data []byte) E.NestedError {
func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
}
@@ -42,7 +42,7 @@ func (p FileProvider) String() string {
return p.fileName
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) {
routes = R.NewRoutes()
b := E.NewBuilder("validation failure")

View File

@@ -7,7 +7,6 @@ import (
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher"
@@ -29,7 +28,7 @@ type (
ProviderImpl interface {
fmt.Stringer
NewWatcher() W.Watcher
LoadRoutesImpl() (R.Routes, E.NestedError)
LoadRoutesImpl() (R.Routes, E.Error)
}
ProviderType string
ProviderStats struct {
@@ -43,7 +42,7 @@ const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 500 * time.Millisecond
providerEventFlushInterval = 300 * time.Millisecond
)
func newProvider(name string, t ProviderType) *Provider {
@@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider {
return p
}
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
func NewFileProvider(filename string) (p *Provider, err E.Error) {
name := path.Base(filename)
if name == "" {
return nil, E.Invalid("file name", "empty")
@@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
return
}
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) {
if name == "" {
return nil, E.Invalid("provider name", "empty")
}
@@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
if entry.UseLoadBalance(r) {
r.Entry.Alias = p.String() + "/" + r.Entry.Alias
}
subtask := parent.Subtask(r.Entry.Alias)
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
err := r.Start(subtask)
if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err.String()) // just to ensure
subtask.Finish(err) // just to ensure
return err
} else {
subtask.OnComplete("del from provider", func() {
p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
}
@@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
}
// Start implements task.TaskStarter.
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
func (p *Provider) Start(configSubtask task.Task) (res E.Error) {
errors := E.NewBuilder("errors starting routes")
defer errors.To(&res)
@@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
handler.Log()
flushTask.Finish("events flushed")
},
func(err E.NestedError) {
func(err E.Error) {
p.l.Error(err)
},
)
@@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError
func (p *Provider) LoadRoutes() E.Error {
var err E.Error
p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 {
return err

94
internal/route/raw.go Normal file
View File

@@ -0,0 +1,94 @@
package route
import (
"errors"
"fmt"
"net"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
RawStream struct {
*StreamRoute
listener net.Listener
targetAddr net.Addr
}
)
const (
streamBufferSize = 8192
streamDialTimeout = 5 * time.Second
)
func NewRawStreamRoute(base *StreamRoute) *RawStream {
return &RawStream{
StreamRoute: base,
}
}
func (route *RawStream) Setup() error {
var lcfg net.ListenConfig
var err error
switch route.Scheme.ListeningScheme {
case "tcp":
route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
route.listener = tcpListener
case "udp":
route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port)
route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener)
default:
return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme))
}
return nil
}
func (route *RawStream) Accept() (net.Conn, error) {
if route.listener == nil {
return nil, errors.New("listener not yet set up")
}
return route.listener.Accept()
}
func (route *RawStream) Handle(c net.Conn) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
route.task.OnCancel("close conn", func() { clientConn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr)
if err != nil {
return err
}
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()
}
func (route *RawStream) Close() error {
return route.listener.Close()
}

View File

@@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container {
return rt.Entry.Container
}
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
func NewRoute(raw *entry.RawEntry) (*Route, E.Error) {
en, err := entry.ValidateEntry(raw)
if err != nil {
return nil, err
@@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
}, nil
}
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
stdNet "net"
"sync"
"github.com/sirupsen/logrus"
@@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] {
return streamRoutes
}
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
// TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
@@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
}, nil
}
func (r *StreamRoute) Finish(reason string) {
r.task.Finish(reason)
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("stream %s", r.Alias)
}
// Start implements task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
func (r *StreamRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
@@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
r.HealthCheck.Disable = true
}
if r.Scheme.ListeningScheme.IsTCP() {
r.Stream = NewTCPRoute(r)
} else {
r.Stream = NewUDPRoute(r)
}
// if r.Scheme.ListeningScheme.IsTCP() {
// r.Stream = NewTCPRoute(r)
// } else {
// r.Stream = NewUDPRoute(r)
// }
r.task = providerSubtask
r.Stream = NewRawStreamRoute(r)
r.l = logrus.WithField("route", r.Stream.String())
switch {
@@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil {
r.task.Finish(err)
return err
}
r.Stream = waker
@@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
case entry.UseHealthCheck(r):
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
}
r.task = providerSubtask
r.task.OnComplete("stop stream", r.CloseListeners)
if err := r.Setup(); err != nil {
r.task.Finish(err)
return E.FailWith("setup", err)
}
r.l.Infof("listening on port %d", r.Port.ListeningPort)
go r.acceptConnections()
r.task.OnFinished("close stream", func() {
if err := r.Close(); err != nil {
r.l.Error("close stream error: ", err)
}
})
r.task.OnFinished("remove from route table", func() {
streamRoutes.Delete(string(r.Alias))
})
r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort)
if r.HealthMon != nil {
r.HealthMon.Start(r.task.Subtask("health monitor"))
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn("health monitor error: ", err)
}
}
go r.acceptConnections()
streamRoutes.Store(string(r.Alias), r)
return nil
}
func (r *StreamRoute) Finish(reason any) {
r.task.Finish(reason)
}
func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed")
for {
select {
case <-r.task.Context().Done():
@@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() {
if err != nil {
select {
case <-r.task.Context().Done():
return
default:
var nErr *stdNet.OpError
ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) {
r.l.Error("accept connection error: ", err)
r.task.Finish(err.Error())
return
}
continue
r.l.Error("accept connection error: ", err)
r.task.Finish(err)
}
return
}
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr()))
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
connTask.Finish(err.Error())
} else {
connTask.Finish("connection closed")
}

View File

@@ -1,71 +1,68 @@
package route
import (
"context"
"fmt"
"net"
"time"
// import (
// "context"
// "fmt"
// "net"
// "time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
// "github.com/yusing/go-proxy/internal/net/types"
// T "github.com/yusing/go-proxy/internal/proxy/fields"
// U "github.com/yusing/go-proxy/internal/utils"
// F "github.com/yusing/go-proxy/internal/utils/functional"
// )
const tcpDialTimeout = 5 * time.Second
// const tcpDialTimeout = 5 * time.Second
type (
TCPConnMap = F.Map[net.Conn, struct{}]
TCPRoute struct {
*StreamRoute
listener *net.TCPListener
}
)
// type (
// TCPConnMap = F.Map[net.Conn, struct{}]
// TCPRoute struct {
// *StreamRoute
// listener *net.TCPListener
// }
// )
func NewTCPRoute(base *StreamRoute) *TCPRoute {
return &TCPRoute{StreamRoute: base}
}
// func NewTCPRoute(base *StreamRoute) *TCPRoute {
// return &TCPRoute{StreamRoute: base}
// }
func (route *TCPRoute) Setup() error {
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
//! this read the allocated port from original ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in.(*net.TCPListener)
return nil
}
// func (route *TCPRoute) Setup() error {
// var cfg net.ListenConfig
// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
// if err != nil {
// return err
// }
// //! this read the allocated port from original ':0'
// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
// route.listener = in.(*net.TCPListener)
// return nil
// }
func (route *TCPRoute) Accept() (types.StreamConn, error) {
route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept()
}
// func (route *TCPRoute) Accept() (types.StreamConn, error) {
// return route.listener.Accept()
// }
func (route *TCPRoute) Handle(c types.StreamConn) error {
clientConn := c.(net.Conn)
// func (route *TCPRoute) Handle(c types.StreamConn) error {
// clientConn := c.(net.Conn)
defer clientConn.Close()
route.task.OnComplete("close conn", func() { clientConn.Close() })
// defer clientConn.Close()
// route.task.OnCancel("close conn", func() { clientConn.Close() })
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{}
// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
// dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
cancel()
if err != nil {
return err
}
// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
// cancel()
// if err != nil {
// return err
// }
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()
}
// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
// return pipe.Start()
// }
func (route *TCPRoute) CloseListeners() {
if route.listener == nil {
return
}
route.listener.Close()
}
// func (route *TCPRoute) Close() error {
// return route.listener.Close()
// }

View File

@@ -1,145 +1,149 @@
package route
import (
"errors"
"fmt"
"io"
"net"
"time"
// import (
// "errors"
// "fmt"
// "io"
// "net"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
// "github.com/yusing/go-proxy/internal/net/types"
// T "github.com/yusing/go-proxy/internal/proxy/fields"
// U "github.com/yusing/go-proxy/internal/utils"
// F "github.com/yusing/go-proxy/internal/utils/functional"
// )
type (
UDPRoute struct {
*StreamRoute
// type (
// UDPRoute struct {
// *StreamRoute
connMap UDPConnMap
// connMap UDPConnMap
listeningConn *net.UDPConn
targetAddr *net.UDPAddr
}
UDPConn struct {
key string
src *net.UDPConn
dst *net.UDPConn
U.BidirectionalPipe
}
UDPConnMap = F.Map[string, *UDPConn]
)
// listeningConn net.PacketConn
// targetAddr *net.UDPAddr
// }
// UDPConn struct {
// key string
// src net.Conn
// dst net.Conn
// U.BidirectionalPipe
// }
// UDPConnMap = F.Map[string, *UDPConn]
// )
var NewUDPConnMap = F.NewMap[UDPConnMap]
// var NewUDPConnMap = F.NewMap[UDPConnMap]
const udpBufferSize = 8192
// const udpBufferSize = 8192
func NewUDPRoute(base *StreamRoute) *UDPRoute {
return &UDPRoute{
StreamRoute: base,
connMap: NewUDPConnMap(),
}
}
// func NewUDPRoute(base *StreamRoute) *UDPRoute {
// return &UDPRoute{
// StreamRoute: base,
// connMap: NewUDPConnMap(),
// }
// }
func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
if err != nil {
return err
}
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
source.Close()
return err
}
// func (route *UDPRoute) Setup() error {
// var cfg net.ListenConfig
// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
// if err != nil {
// return err
// }
// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
// if err != nil {
// source.Close()
// return err
// }
//! this read the allocated listeningPort from original ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
// //! this read the allocated listeningPort from original ':0'
// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
route.targetAddr = raddr
// route.listeningConn = source
// route.targetAddr = raddr
return nil
}
// return nil
// }
func (route *UDPRoute) Accept() (types.StreamConn, error) {
in := route.listeningConn
// func (route *UDPRoute) Accept() (types.StreamConn, error) {
// in := route.listeningConn
buffer := make([]byte, udpBufferSize)
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
// buffer := make([]byte, udpBufferSize)
// nRead, srcAddr, err := in.ReadFrom(buffer)
// if err != nil {
// return nil, err
// }
if nRead == 0 {
return nil, io.ErrShortBuffer
}
// if nRead == 0 {
// return nil, io.ErrShortBuffer
// }
key := srcAddr.String()
conn, ok := route.connMap.Load(key)
// key := srcAddr.String()
// conn, ok := route.connMap.Load(key)
if !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
if err != nil {
srcConn.Close()
return nil, err
}
conn = &UDPConn{
key,
srcConn,
dstConn,
U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap.Store(key, conn)
}
// if !ok {
// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String())
// if err != nil {
// return nil, err
// }
// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String())
// if err != nil {
// srcConn.Close()
// return nil, err
// }
// conn = &UDPConn{
// key,
// srcConn,
// dstConn,
// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
// }
// route.connMap.Store(key, conn)
// }
_, err = conn.dst.Write(buffer[:nRead])
return conn, err
}
// _, err = conn.dst.Write(buffer[:nRead])
// return conn, err
// }
func (route *UDPRoute) Handle(c types.StreamConn) error {
conn := c.(*UDPConn)
err := conn.Start()
route.connMap.Delete(conn.key)
return err
}
// func (route *UDPRoute) Handle(c types.StreamConn) error {
// switch c := c.(type) {
// case *UDPConn:
// err := c.Start()
// route.connMap.Delete(c.key)
// c.Close()
// return err
// case *net.TCPConn:
// in := route.listeningConn
// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr))
// if err != nil {
// return err
// }
// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start()
// c.Close()
// return err
// }
// return fmt.Errorf("unknown conn type: %T", c)
// }
func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
}
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.Close(); err != nil {
route.l.Errorf("error closing conn: %s", err)
}
})
route.connMap.Clear()
}
// func (route *UDPRoute) Close() error {
// route.connMap.RangeAllParallel(func(k string, v *UDPConn) {
// v.Close()
// })
// route.connMap.Clear()
// return route.listeningConn.Close()
// }
// Close implements types.StreamConn
func (conn *UDPConn) Close() error {
return errors.Join(conn.src.Close(), conn.dst.Close())
}
// // Close implements types.StreamConn
// func (conn *UDPConn) Close() error {
// return errors.Join(conn.src.Close(), conn.dst.Close())
// }
// RemoteAddr implements types.StreamConn
func (conn *UDPConn) RemoteAddr() net.Addr {
return conn.src.RemoteAddr()
}
// // RemoteAddr implements types.StreamConn
// func (conn *UDPConn) RemoteAddr() net.Addr {
// return conn.src.RemoteAddr()
// }
type sourceRWCloser struct {
server *net.UDPConn
*net.UDPConn
}
// type sourceRWCloser struct {
// server net.PacketConn
// net.Conn
// }
func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}
// func (w sourceRWCloser) Write(p []byte) (int, error) {
// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr))
// }

View File

@@ -0,0 +1,73 @@
package route
import (
"context"
"io"
"net"
"sync"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPListener struct {
ctx context.Context
listener net.PacketConn
connMap UDPConnMap
mu sync.Mutex
}
UDPConnMap = F.Map[string, net.Conn]
)
var NewUDPConnMap = F.NewMap[UDPConnMap]
func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener {
return &UDPListener{
ctx: ctx,
listener: listener,
connMap: NewUDPConnMap(),
}
}
// Addr implements net.Listener.
func (route *UDPListener) Addr() net.Addr {
return route.listener.LocalAddr()
}
func (udpl *UDPListener) Accept() (net.Conn, error) {
in := udpl.listener
buffer := make([]byte, streamBufferSize)
nRead, srcAddr, err := in.ReadFrom(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
udpl.mu.Lock()
defer udpl.mu.Unlock()
key := srcAddr.String()
conn, ok := udpl.connMap.Load(key)
if !ok {
dialer := &net.Dialer{Timeout: streamDialTimeout}
srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String())
if err != nil {
return nil, err
}
udpl.connMap.Store(key, srcConn)
}
return conn, nil
}
// Close implements net.Listener.
func (route *UDPListener) Close() error {
route.connMap.RangeAllParallel(func(key string, conn net.Conn) {
conn.Close()
})
route.connMap.Clear()
return route.listener.Close()
}