diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index d7128f39..404554e4 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" @@ -16,6 +15,7 @@ import ( "strings" "time" + "github.com/bytedance/sonic" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/yusing/godoxy/agent/pkg/agent/common" @@ -150,7 +150,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte) // test stream server connection const fakeAddress = "localhost:8080" // it won't be used, just for testing // test TCP stream support - err := agentstream.TCPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) + err := agentstream.TCPHealthCheck(ctx, cfg.Addr, cfg.caCert, cfg.clientCert) if err != nil { streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err) } else { @@ -158,7 +158,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte) } // test UDP stream support - err = agentstream.UDPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) + err = agentstream.UDPHealthCheck(ctx, cfg.Addr, cfg.caCert, cfg.clientCert) if err != nil { streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err) } else { @@ -313,8 +313,18 @@ func (cfg *AgentConfig) do(ctx context.Context, method, endpoint string, body io if err != nil { return nil, err } + + timeout := 5 * time.Second + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + timeout = remaining + } + } + client := http.Client{ Transport: cfg.Transport(), + Timeout: timeout, } return client.Do(req) } @@ -356,7 +366,7 @@ func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) return resp.StatusCode, nil } - err = json.Unmarshal(data, out) + err = sonic.Unmarshal(data, out) if err != nil { return 0, err } diff --git a/agent/pkg/agent/stream/tcp_client.go b/agent/pkg/agent/stream/tcp_client.go index 3a9db398..75dad5c3 100644 --- a/agent/pkg/agent/stream/tcp_client.go +++ b/agent/pkg/agent/stream/tcp_client.go @@ -1,6 +1,7 @@ package stream import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -34,13 +35,13 @@ func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, cl return nil, err } - return newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) + return newTCPClientWIthHeader(context.Background(), serverAddr, header, caCert, clientCert) } -func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { +func TCPHealthCheck(ctx context.Context, serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { header := NewStreamHealthCheckHeader() - conn, err := newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) + conn, err := newTCPClientWIthHeader(ctx, serverAddr, header, caCert, clientCert) if err != nil { return err } @@ -49,7 +50,7 @@ func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls return nil } -func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { +func newTCPClientWIthHeader(ctx context.Context, serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { // Setup TLS configuration caCertPool := x509.NewCertPool() caCertPool.AddCert(caCert) @@ -62,17 +63,43 @@ func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCe ServerName: common.CertsDNSName, } + dialer := &net.Dialer{ + Timeout: dialTimeout, + } + tlsDialer := &tls.Dialer{ + NetDialer: dialer, + Config: tlsConfig, + } + // Establish TLS connection - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig) + conn, err := tlsDialer.DialContext(ctx, "tcp", serverAddr) if err != nil { return nil, err } + + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + err := conn.SetWriteDeadline(deadline) + if err != nil { + _ = conn.Close() + return nil, err + } + } // Send the stream header once as a handshake. if _, err := conn.Write(header.Bytes()); err != nil { _ = conn.Close() return nil, err } + if hasDeadline { + // reset write deadline + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + _ = conn.Close() + return nil, err + } + } + return &TCPClient{ conn: conn, }, nil diff --git a/agent/pkg/agent/stream/tests/healthcheck_test.go b/agent/pkg/agent/stream/tests/healthcheck_test.go index 320e29a6..282d4caf 100644 --- a/agent/pkg/agent/stream/tests/healthcheck_test.go +++ b/agent/pkg/agent/stream/tests/healthcheck_test.go @@ -12,7 +12,7 @@ func TestTCPHealthCheck(t *testing.T) { srv := startTCPServer(t, certs) - err := stream.TCPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) + err := stream.TCPHealthCheck(t.Context(), srv.Addr.String(), certs.CaCert, certs.ClientCert) require.NoError(t, err, "health check") } @@ -21,6 +21,6 @@ func TestUDPHealthCheck(t *testing.T) { srv := startUDPServer(t, certs) - err := stream.UDPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) + err := stream.UDPHealthCheck(t.Context(), srv.Addr.String(), certs.CaCert, certs.ClientCert) require.NoError(t, err, "health check") } diff --git a/agent/pkg/agent/stream/udp_client.go b/agent/pkg/agent/stream/udp_client.go index 4d372be8..24941991 100644 --- a/agent/pkg/agent/stream/udp_client.go +++ b/agent/pkg/agent/stream/udp_client.go @@ -1,6 +1,7 @@ package stream import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -35,10 +36,10 @@ func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, cl return nil, err } - return newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) + return newUDPClientWIthHeader(context.Background(), serverAddr, header, caCert, clientCert) } -func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { +func newUDPClientWIthHeader(ctx context.Context, serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { // Setup DTLS configuration caCertPool := x509.NewCertPool() caCertPool.AddCert(caCert) @@ -62,21 +63,40 @@ func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCe if err != nil { return nil, err } + + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + err := conn.SetWriteDeadline(deadline) + if err != nil { + _ = conn.Close() + return nil, err + } + } + // Send the stream header once as a handshake. if _, err := conn.Write(header.Bytes()); err != nil { _ = conn.Close() return nil, err } + if hasDeadline { + // reset write deadline + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + _ = conn.Close() + return nil, err + } + } + return &UDPClient{ conn: conn, }, nil } -func UDPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { +func UDPHealthCheck(ctx context.Context, serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { header := NewStreamHealthCheckHeader() - conn, err := newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) + conn, err := newUDPClientWIthHeader(ctx, serverAddr, header, caCert, clientCert) if err != nil { return err } diff --git a/internal/agentpool/agent.go b/internal/agentpool/agent.go index 59fe1e77..b85aeb0d 100644 --- a/internal/agentpool/agent.go +++ b/internal/agentpool/agent.go @@ -27,6 +27,7 @@ func newAgent(cfg *agent.AgentConfig) *Agent { AgentConfig: cfg, httpClient: &http.Client{ Transport: transport, + Timeout: 5 * time.Second, }, fasthttpHcClient: &fasthttp.Client{ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { diff --git a/internal/health/check/http.go b/internal/health/check/http.go index fb946131..5ee11688 100644 --- a/internal/health/check/http.go +++ b/internal/health/check/http.go @@ -76,8 +76,11 @@ func H2C(ctx context.Context, url *url.URL, method, path string, timeout time.Du setCommonHeaders(req.Header.Set) + client := *h2cClient + client.Timeout = timeout + start := time.Now() - resp, err := h2cClient.Do(req) + resp, err := client.Do(req) lat := time.Since(start) if resp != nil { diff --git a/internal/proxmox/config.go b/internal/proxmox/config.go index 35d54445..54269adf 100644 --- a/internal/proxmox/config.go +++ b/internal/proxmox/config.go @@ -162,4 +162,4 @@ func (c *Config) refreshSessionLoop(ctx context.Context) { } } } -} \ No newline at end of file +}