diff --git a/agent/pkg/agent/stream/tcp_server.go b/agent/pkg/agent/stream/tcp_server.go index 9432746c..a9cae161 100644 --- a/agent/pkg/agent/stream/tcp_server.go +++ b/agent/pkg/agent/stream/tcp_server.go @@ -7,6 +7,7 @@ import ( "errors" "io" "net" + "time" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -147,9 +148,11 @@ func (s *TCPServer) handle(conn net.Conn) { func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) { // Read the stream header once as a handshake. var headerBuf [headerSize]byte + _ = conn.SetReadDeadline(time.Now().Add(dialTimeout)) if _, err := io.ReadFull(conn, headerBuf[:]); err != nil { return nil, err } + _ = conn.SetReadDeadline(time.Time{}) header := ToHeader(&headerBuf) if !header.Validate() { diff --git a/agent/pkg/agent/stream/udp_server.go b/agent/pkg/agent/stream/udp_server.go index 0ceb7124..bea30929 100644 --- a/agent/pkg/agent/stream/udp_server.go +++ b/agent/pkg/agent/stream/udp_server.go @@ -102,10 +102,13 @@ func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) { // Read the stream header once as a handshake. var headerBuf [headerSize]byte + _ = clientConn.SetReadDeadline(time.Now().Add(dialTimeout)) if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil { s.logger(clientConn).Err(err).Msg("failed to read stream header") return } + _ = clientConn.SetReadDeadline(time.Time{}) + header := ToHeader(&headerBuf) if !header.Validate() { s.logger(clientConn).Error().Bytes("header", headerBuf[:]).Msg("invalid stream header received")