fix stream task stuck on reload and udp mutex not unlocked properly

This commit is contained in:
yusing
2025-01-05 03:26:31 +08:00
parent e04080bf1c
commit 5e2ce9e1e6
3 changed files with 43 additions and 28 deletions

View File

@@ -43,16 +43,19 @@ func (stream *Stream) Setup() error {
var lcfg net.ListenConfig
var err error
ctx := stream.task.Context()
switch stream.Scheme.ListeningScheme {
case "tcp":
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host)
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", stream.ListenURL.Host)
tcpListener, err := lcfg.Listen(ctx, "tcp", stream.ListenURL.Host)
if err != nil {
return err
}
// in case ListeningPort was zero, get the actual port
stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
stream.listener = types.NetListener(tcpListener)
case "udp":
@@ -60,7 +63,7 @@ func (stream *Stream) Setup() error {
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", stream.ListenURL.Host)
udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.ListenURL.Host)
if err != nil {
return err
}
@@ -70,7 +73,7 @@ func (stream *Stream) Setup() error {
return errors.New("udp listener is not *net.UDPConn")
}
stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port)
stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr)
stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
default:
panic("should not reach here")
}
@@ -78,11 +81,24 @@ func (stream *Stream) Setup() error {
return nil
}
func (stream *Stream) Accept() (types.StreamConn, error) {
func (stream *Stream) Accept() (conn types.StreamConn, err error) {
if stream.listener == nil {
return nil, errors.New("listener is nil")
}
return stream.listener.Accept()
// prevent Accept from blocking forever
done := make(chan struct{})
go func() {
conn, err = stream.listener.Accept()
close(done)
}()
select {
case <-stream.task.Context().Done():
stream.Close()
return nil, stream.task.Context().Err()
case <-done:
return conn, nil
}
}
func (stream *Stream) Handle(conn types.StreamConn) error {
@@ -95,14 +111,13 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
return fmt.Errorf("unexpected listener type: %T", stream)
}
case io.ReadWriteCloser:
stream.task.OnCancel("close_conn", func() { conn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
if err != nil {
return err
}
defer dstConn.Close()
defer conn.Close()
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
return pipe.Start()
default: