refactor: propagate context and standardize HTTP client timeouts

Add context parameter to TCP/UDP stream health checks and client constructors
for proper cancellation and deadline propagation. Switch from encoding/json
to sonic for faster JSON unmarshaling.

Standardize HTTP client timeouts to 5 seconds
across agent pool and health check.
This commit is contained in:
yusing
2026-01-30 00:23:03 +08:00
parent 0f13004ad6
commit 6528fb0a8d
7 changed files with 78 additions and 17 deletions

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@@ -16,6 +15,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/bytedance/sonic"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/yusing/godoxy/agent/pkg/agent/common" "github.com/yusing/godoxy/agent/pkg/agent/common"
@@ -150,7 +150,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
// test stream server connection // test stream server connection
const fakeAddress = "localhost:8080" // it won't be used, just for testing const fakeAddress = "localhost:8080" // it won't be used, just for testing
// test TCP stream support // test TCP stream support
err := agentstream.TCPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) err := agentstream.TCPHealthCheck(ctx, cfg.Addr, cfg.caCert, cfg.clientCert)
if err != nil { if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err) streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
} else { } else {
@@ -158,7 +158,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
} }
// test UDP stream support // test UDP stream support
err = agentstream.UDPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) err = agentstream.UDPHealthCheck(ctx, cfg.Addr, cfg.caCert, cfg.clientCert)
if err != nil { if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err) streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
} else { } else {
@@ -313,8 +313,18 @@ func (cfg *AgentConfig) do(ctx context.Context, method, endpoint string, body io
if err != nil { if err != nil {
return nil, err return nil, err
} }
timeout := 5 * time.Second
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
if remaining > 0 {
timeout = remaining
}
}
client := http.Client{ client := http.Client{
Transport: cfg.Transport(), Transport: cfg.Transport(),
Timeout: timeout,
} }
return client.Do(req) return client.Do(req)
} }
@@ -356,7 +366,7 @@ func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any)
return resp.StatusCode, nil return resp.StatusCode, nil
} }
err = json.Unmarshal(data, out) err = sonic.Unmarshal(data, out)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -1,6 +1,7 @@
package stream package stream
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"net" "net"
@@ -34,13 +35,13 @@ func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, cl
return nil, err return nil, err
} }
return newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) return newTCPClientWIthHeader(context.Background(), serverAddr, header, caCert, clientCert)
} }
func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { func TCPHealthCheck(ctx context.Context, serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error {
header := NewStreamHealthCheckHeader() header := NewStreamHealthCheckHeader()
conn, err := newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) conn, err := newTCPClientWIthHeader(ctx, serverAddr, header, caCert, clientCert)
if err != nil { if err != nil {
return err return err
} }
@@ -49,7 +50,7 @@ func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls
return nil return nil
} }
func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { func newTCPClientWIthHeader(ctx context.Context, serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
// Setup TLS configuration // Setup TLS configuration
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert) caCertPool.AddCert(caCert)
@@ -62,17 +63,43 @@ func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCe
ServerName: common.CertsDNSName, ServerName: common.CertsDNSName,
} }
dialer := &net.Dialer{
Timeout: dialTimeout,
}
tlsDialer := &tls.Dialer{
NetDialer: dialer,
Config: tlsConfig,
}
// Establish TLS connection // Establish TLS connection
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig) conn, err := tlsDialer.DialContext(ctx, "tcp", serverAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
err := conn.SetWriteDeadline(deadline)
if err != nil {
_ = conn.Close()
return nil, err
}
}
// Send the stream header once as a handshake. // Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil { if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
if hasDeadline {
// reset write deadline
err = conn.SetWriteDeadline(time.Time{})
if err != nil {
_ = conn.Close()
return nil, err
}
}
return &TCPClient{ return &TCPClient{
conn: conn, conn: conn,
}, nil }, nil

View File

@@ -12,7 +12,7 @@ func TestTCPHealthCheck(t *testing.T) {
srv := startTCPServer(t, certs) srv := startTCPServer(t, certs)
err := stream.TCPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) err := stream.TCPHealthCheck(t.Context(), srv.Addr.String(), certs.CaCert, certs.ClientCert)
require.NoError(t, err, "health check") require.NoError(t, err, "health check")
} }
@@ -21,6 +21,6 @@ func TestUDPHealthCheck(t *testing.T) {
srv := startUDPServer(t, certs) srv := startUDPServer(t, certs)
err := stream.UDPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) err := stream.UDPHealthCheck(t.Context(), srv.Addr.String(), certs.CaCert, certs.ClientCert)
require.NoError(t, err, "health check") require.NoError(t, err, "health check")
} }

View File

@@ -1,6 +1,7 @@
package stream package stream
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"net" "net"
@@ -35,10 +36,10 @@ func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, cl
return nil, err return nil, err
} }
return newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) return newUDPClientWIthHeader(context.Background(), serverAddr, header, caCert, clientCert)
} }
func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { func newUDPClientWIthHeader(ctx context.Context, serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
// Setup DTLS configuration // Setup DTLS configuration
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert) caCertPool.AddCert(caCert)
@@ -62,21 +63,40 @@ func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCe
if err != nil { if err != nil {
return nil, err return nil, err
} }
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
err := conn.SetWriteDeadline(deadline)
if err != nil {
_ = conn.Close()
return nil, err
}
}
// Send the stream header once as a handshake. // Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil { if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
if hasDeadline {
// reset write deadline
err = conn.SetWriteDeadline(time.Time{})
if err != nil {
_ = conn.Close()
return nil, err
}
}
return &UDPClient{ return &UDPClient{
conn: conn, conn: conn,
}, nil }, nil
} }
func UDPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { func UDPHealthCheck(ctx context.Context, serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error {
header := NewStreamHealthCheckHeader() header := NewStreamHealthCheckHeader()
conn, err := newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) conn, err := newUDPClientWIthHeader(ctx, serverAddr, header, caCert, clientCert)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -27,6 +27,7 @@ func newAgent(cfg *agent.AgentConfig) *Agent {
AgentConfig: cfg, AgentConfig: cfg,
httpClient: &http.Client{ httpClient: &http.Client{
Transport: transport, Transport: transport,
Timeout: 5 * time.Second,
}, },
fasthttpHcClient: &fasthttp.Client{ fasthttpHcClient: &fasthttp.Client{
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {

View File

@@ -76,8 +76,11 @@ func H2C(ctx context.Context, url *url.URL, method, path string, timeout time.Du
setCommonHeaders(req.Header.Set) setCommonHeaders(req.Header.Set)
client := *h2cClient
client.Timeout = timeout
start := time.Now() start := time.Now()
resp, err := h2cClient.Do(req) resp, err := client.Do(req)
lat := time.Since(start) lat := time.Since(start)
if resp != nil { if resp != nil {

View File

@@ -162,4 +162,4 @@ func (c *Config) refreshSessionLoop(ctx context.Context) {
} }
} }
} }
} }