fixes, meaningful error messages and new features

This commit is contained in:
yusing
2024-03-27 06:30:47 +00:00
parent 539ef911de
commit 90f4aac946
50 changed files with 2079 additions and 885 deletions

View File

@@ -1,7 +1,6 @@
package main
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -11,16 +10,18 @@ import (
"github.com/sirupsen/logrus"
)
type StreamImpl interface {
Setup() error
Accept() (interface{}, error)
Handle(interface{}) error
CloseListeners()
}
type StreamRoute interface {
Route
ListeningUrl() string
TargetUrl() string
Logger() logrus.FieldLogger
closeListeners()
closeChannel()
unmarkPort()
wait()
}
type StreamRouteBase struct {
@@ -32,10 +33,14 @@ type StreamRouteBase struct {
TargetHost string
TargetPort int
id string
wg sync.WaitGroup
stopChann chan struct{}
l logrus.FieldLogger
id string
wg sync.WaitGroup
stopCh chan struct{}
connCh chan interface{}
started bool
l logrus.FieldLogger
StreamImpl
}
func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
@@ -45,14 +50,14 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
var srcScheme string
var dstScheme string
port_split := strings.Split(config.Port, ":")
if len(port_split) != 2 {
portSplit := strings.Split(config.Port, ":")
if len(portSplit) != 2 {
cfgl.Warnf("invalid port %s, assuming it is target port", config.Port)
srcPort = "0"
dstPort = config.Port
} else {
srcPort = port_split[0]
dstPort = port_split[1]
srcPort = portSplit[0]
dstPort = portSplit[1]
}
if port, hasName := NamePortMap[dstPort]; hasName {
@@ -61,25 +66,20 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
srcPortInt, err := strconv.Atoi(srcPort)
if err != nil {
return nil, fmt.Errorf(
"invalid stream source port %s, ignoring", srcPort,
)
return nil, NewNestedError("invalid stream source port").Subject(srcPort)
}
utils.markPortInUse(srcPortInt)
dstPortInt, err := strconv.Atoi(dstPort)
if err != nil {
return nil, fmt.Errorf(
"invalid stream target port %s, ignoring", dstPort,
)
return nil, NewNestedError("invalid stream target port").Subject(dstPort)
}
scheme_split := strings.Split(config.Scheme, ":")
if len(scheme_split) == 2 {
srcScheme = scheme_split[0]
dstScheme = scheme_split[1]
schemeSplit := strings.Split(config.Scheme, ":")
if len(schemeSplit) == 2 {
srcScheme = schemeSplit[0]
dstScheme = schemeSplit[1]
} else {
srcScheme = config.Scheme
dstScheme = config.Scheme
@@ -94,9 +94,11 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
TargetHost: config.Host,
TargetPort: dstPortInt,
id: config.GetID(),
wg: sync.WaitGroup{},
stopChann: make(chan struct{}, 1),
id: config.GetID(),
wg: sync.WaitGroup{},
stopCh: make(chan struct{}, 1),
connCh: make(chan interface{}),
started: false,
l: srlog.WithFields(logrus.Fields{
"alias": config.Alias,
"src": fmt.Sprintf("%s://:%d", srcScheme, srcPortInt),
@@ -112,7 +114,7 @@ func NewStreamRoute(config *ProxyConfig) (StreamRoute, error) {
case StreamType_UDP:
return NewUDPRoute(config)
default:
return nil, errors.New("unknown stream type")
return nil, NewNestedError("invalid stream type").Subject(config.Scheme)
}
}
@@ -128,7 +130,45 @@ func (route *StreamRouteBase) Logger() logrus.FieldLogger {
return route.l
}
func (route *StreamRouteBase) setupListen() {
func (route *StreamRouteBase) Start() {
route.ensurePort()
if err := route.Setup(); err != nil {
route.l.Errorf("failed to setup: %v", err)
return
}
route.started = true
route.wg.Add(2)
go route.grAcceptConnections()
go route.grHandleConnections()
}
func (route *StreamRouteBase) Stop() {
if !route.started {
return
}
l := route.Logger()
l.Debug("stopping listening")
close(route.stopCh)
route.CloseListeners()
done := make(chan struct{}, 1)
go func() {
route.wg.Wait()
close(done)
}()
select {
case <-done:
l.Info("stopped listening")
case <-time.After(streamStopListenTimeout):
l.Error("timed out waiting for connections")
}
utils.unmarkPortInUse(route.ListeningPort)
streamRoutes.Delete(route.id)
}
func (route *StreamRouteBase) ensurePort() {
if route.ListeningPort == 0 {
freePort, err := utils.findUseFreePort(20000)
if err != nil {
@@ -142,40 +182,43 @@ func (route *StreamRouteBase) setupListen() {
route.l.Info("listening on ", route.ListeningUrl())
}
func (route *StreamRouteBase) wait() {
route.wg.Wait()
}
func (route *StreamRouteBase) grAcceptConnections() {
defer route.wg.Done()
func (route *StreamRouteBase) closeChannel() {
close(route.stopChann)
}
func (route *StreamRouteBase) unmarkPort() {
utils.unmarkPortInUse(route.ListeningPort)
}
func stopListening(route StreamRoute) {
l := route.Logger()
l.Debug("stopping listening")
// close channel -> wait -> close listeners
route.closeChannel()
done := make(chan struct{})
go func() {
route.wait()
close(done)
route.unmarkPort()
}()
select {
case <-done:
l.Info("stopped listening")
case <-time.After(StreamStopListenTimeout):
l.Error("timed out waiting for connections")
for {
select {
case <-route.stopCh:
return
default:
conn, err := route.Accept()
if err != nil {
select {
case <-route.stopCh:
return
default:
route.l.Error(err)
continue
}
}
route.connCh <- conn
}
}
}
func (route *StreamRouteBase) grHandleConnections() {
defer route.wg.Done()
for {
select {
case <-route.stopCh:
return
case conn := <-route.connCh:
go func() {
err := route.Handle(conn)
if err != nil {
route.l.Error(err)
}
}()
}
}
route.closeListeners()
}