mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-19 15:21:35 +02:00
lint and leftover
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
committed by
Kristoffer Dalby
parent
39443184d6
commit
233dffc186
@@ -146,12 +146,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
|
||||
policyChanged, err := app.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
|
||||
log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
|
||||
return
|
||||
}
|
||||
|
||||
app.Change(policyChanged)
|
||||
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
|
||||
})
|
||||
app.ephemeralGC = ephemeralGC
|
||||
|
||||
@@ -384,53 +384,49 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
Msg("HTTP authentication invoked")
|
||||
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
|
||||
if !strings.HasPrefix(authHeader, AuthPrefix) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
return err
|
||||
}
|
||||
|
||||
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("failed to validate token")
|
||||
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
return err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
log.Info().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("invalid token")
|
||||
|
||||
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}(); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("failed to validate token")
|
||||
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !valid {
|
||||
log.Info().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("invalid token")
|
||||
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
Msg("Failed to write HTTP response")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ func NewHeadscaleDatabase(
|
||||
log.Error().Err(err).Msg("Error creating route")
|
||||
} else {
|
||||
log.Info().
|
||||
Uint64("node_id", route.NodeID).
|
||||
Uint64("node.id", route.NodeID).
|
||||
Str("prefix", prefix.String()).
|
||||
Msg("Route migrated")
|
||||
}
|
||||
@@ -870,23 +870,23 @@ AND auth_key_id NOT IN (
|
||||
// Copy data directly using SQL
|
||||
dataCopySQL := []string{
|
||||
`INSERT INTO users (id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at)
|
||||
SELECT id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at
|
||||
SELECT id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at
|
||||
FROM users_old`,
|
||||
|
||||
`INSERT INTO pre_auth_keys (id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at)
|
||||
SELECT id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at
|
||||
SELECT id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at
|
||||
FROM pre_auth_keys_old`,
|
||||
|
||||
`INSERT INTO api_keys (id, prefix, hash, expiration, last_seen, created_at)
|
||||
SELECT id, prefix, hash, expiration, last_seen, created_at
|
||||
SELECT id, prefix, hash, expiration, last_seen, created_at
|
||||
FROM api_keys_old`,
|
||||
|
||||
`INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at)
|
||||
SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at
|
||||
SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at
|
||||
FROM nodes_old`,
|
||||
|
||||
`INSERT INTO policies (id, data, created_at, updated_at, deleted_at)
|
||||
SELECT id, data, created_at, updated_at, deleted_at
|
||||
SELECT id, data, created_at, updated_at, deleted_at
|
||||
FROM policies_old`,
|
||||
}
|
||||
|
||||
@@ -1131,7 +1131,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
}
|
||||
|
||||
for _, migrationID := range migrationIDs {
|
||||
log.Trace().Str("migration_id", migrationID).Msg("Running migration")
|
||||
log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
|
||||
needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
|
||||
|
||||
if needsFKDisabled {
|
||||
|
||||
@@ -275,7 +275,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
return errors.New("backfilling IPs: ip allocator was nil")
|
||||
}
|
||||
|
||||
log.Trace().Msgf("starting to backfill IPs")
|
||||
log.Trace().Caller().Msgf("starting to backfill IPs")
|
||||
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
@@ -283,7 +283,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
|
||||
log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
|
||||
|
||||
changed := false
|
||||
// IPv4 prefix is set, but node ip is missing, alloc
|
||||
|
||||
@@ -34,9 +34,6 @@ var (
|
||||
"node not found in registration cache",
|
||||
)
|
||||
ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
|
||||
ErrDifferentRegisteredUser = errors.New(
|
||||
"node was previously registered with a different user",
|
||||
)
|
||||
)
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/arl/statsviz"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"tailscale.com/tsweb"
|
||||
@@ -239,6 +240,34 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
w.Write(resJSON)
|
||||
}))
|
||||
|
||||
// Batcher endpoint
|
||||
debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check Accept header to determine response format
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
wantsJSON := strings.Contains(acceptHeader, "application/json")
|
||||
|
||||
if wantsJSON {
|
||||
batcherInfo := h.debugBatcherJSON()
|
||||
|
||||
batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(batcherJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
batcherInfo := h.debugBatcher()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(batcherInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
err := statsviz.Register(debugMux)
|
||||
if err == nil {
|
||||
debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)")
|
||||
@@ -256,3 +285,124 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
return debugHTTPServer
|
||||
}
|
||||
|
||||
// debugBatcher returns debug information about the batcher's connected nodes.
|
||||
func (h *Headscale) debugBatcher() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("=== Batcher Connected Nodes ===\n\n")
|
||||
|
||||
totalNodes := 0
|
||||
connectedCount := 0
|
||||
|
||||
// Collect nodes and sort them by ID
|
||||
type nodeStatus struct {
|
||||
id types.NodeID
|
||||
connected bool
|
||||
activeConnections int
|
||||
}
|
||||
|
||||
var nodes []nodeStatus
|
||||
|
||||
// Try to get detailed debug info if we have a LockFreeBatcher
|
||||
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
|
||||
debugInfo := batcher.Debug()
|
||||
for nodeID, info := range debugInfo {
|
||||
nodes = append(nodes, nodeStatus{
|
||||
id: nodeID,
|
||||
connected: info.Connected,
|
||||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to basic connection info
|
||||
connectedMap := h.mapBatcher.ConnectedMap()
|
||||
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
|
||||
nodes = append(nodes, nodeStatus{
|
||||
id: nodeID,
|
||||
connected: connected,
|
||||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by node ID
|
||||
for i := 0; i < len(nodes); i++ {
|
||||
for j := i + 1; j < len(nodes); j++ {
|
||||
if nodes[i].id > nodes[j].id {
|
||||
nodes[i], nodes[j] = nodes[j], nodes[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Output sorted nodes
|
||||
for _, node := range nodes {
|
||||
status := "disconnected"
|
||||
if node.connected {
|
||||
status = "connected"
|
||||
}
|
||||
|
||||
if node.activeConnections > 0 {
|
||||
sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes))
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Structured (JSON) representation of the batcher debug output.
type (
	// DebugBatcherInfo represents batcher connection information in a
	// structured format.
	DebugBatcherInfo struct {
		// ConnectedNodes maps a node ID, rendered as a string, to that
		// node's connection info. Nodes that are currently not connected
		// are listed too, with Connected set to false.
		ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"`
		// TotalNodes is the number of nodes reported.
		TotalNodes int `json:"total_nodes"`
	}

	// DebugBatcherNodeInfo represents connection information for a single
	// node.
	DebugBatcherNodeInfo struct {
		// Connected reports whether the node is considered connected.
		Connected bool `json:"connected"`
		// ActiveConnections is the node's number of active connections.
		ActiveConnections int `json:"active_connections"`
	}
)
|
||||
|
||||
// debugBatcherJSON returns structured debug information about the batcher's connected nodes.
|
||||
func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
||||
info := DebugBatcherInfo{
|
||||
ConnectedNodes: make(map[string]DebugBatcherNodeInfo),
|
||||
TotalNodes: 0,
|
||||
}
|
||||
|
||||
// Try to get detailed debug info if we have a LockFreeBatcher
|
||||
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
|
||||
debugInfo := batcher.Debug()
|
||||
for nodeID, debugData := range debugInfo {
|
||||
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
|
||||
Connected: debugData.Connected,
|
||||
ActiveConnections: debugData.ActiveConnections,
|
||||
}
|
||||
info.TotalNodes++
|
||||
}
|
||||
} else {
|
||||
// Fallback to basic connection info
|
||||
connectedMap := h.mapBatcher.ConnectedMap()
|
||||
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
|
||||
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func (d *DERPServer) DERPHandler(
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
@@ -199,7 +199,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
@@ -229,7 +229,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
@@ -245,7 +245,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
@@ -284,7 +284,7 @@ func DERPProbeHandler(
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -330,7 +330,7 @@ func DERPBootstrapDNSHandler(
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +237,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
request *v1.RegisterNodeRequest,
|
||||
) (*v1.RegisterNodeResponse, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("user", request.GetUser()).
|
||||
Str("registration_id", request.GetKey()).
|
||||
Msg("Registering node")
|
||||
@@ -525,7 +526,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
|
||||
ctx context.Context,
|
||||
request *v1.BackfillNodeIPsRequest,
|
||||
) (*v1.BackfillNodeIPsResponse, error) {
|
||||
log.Trace().Msg("Backfill called")
|
||||
log.Trace().Caller().Msg("Backfill called")
|
||||
|
||||
if !request.Confirmed {
|
||||
return nil, errors.New("not confirmed, aborting")
|
||||
@@ -709,6 +710,10 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
UpdatedAt: timestamppb.New(updated.UpdatedAt),
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Msg("gRPC SetPolicy completed successfully because response prepared")
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
@@ -731,7 +736,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
Caller().
|
||||
Interface("route-prefix", routes).
|
||||
Interface("route-str", request.GetRoutes()).
|
||||
Msg("")
|
||||
Msg("Creating routes for node")
|
||||
|
||||
hostinfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: routes,
|
||||
@@ -760,6 +765,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("registration_id", registrationId.String()).
|
||||
Msg("adding debug machine via CLI, appending to registration cache")
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ func (h *Headscale) RobotsHandler(
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ type Batcher interface {
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
AddWork(c ...change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
|
||||
}
|
||||
@@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
|
||||
|
||||
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||
}
|
||||
@@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||
m := newMapper(cfg, state)
|
||||
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||
m.batcher = b
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
}
|
||||
|
||||
var mapResp *tailcfg.MapResponse
|
||||
var err error
|
||||
var (
|
||||
mapResp *tailcfg.MapResponse
|
||||
err error
|
||||
)
|
||||
|
||||
switch c.Change {
|
||||
case change.DERP:
|
||||
@@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
} else {
|
||||
// CRITICAL FIX: Read actual online status from NodeStore when available,
|
||||
// fall back to deriving from change type for unit tests or when NodeStore is empty
|
||||
var onlineStatus bool
|
||||
if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
|
||||
// Use actual NodeStore status when available (production case)
|
||||
onlineStatus = node.IsOnline().Get()
|
||||
} else {
|
||||
// Fall back to deriving from change type (unit test case or initial setup)
|
||||
onlineStatus = c.Change == change.NodeCameOnline
|
||||
}
|
||||
|
||||
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: c.NodeID.NodeID(),
|
||||
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||
Online: ptr.To(onlineStatus),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
|
||||
|
||||
var data *tailcfg.MapResponse
|
||||
var err error
|
||||
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||
}
|
||||
@@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
|
||||
}
|
||||
|
||||
// Send the map response
|
||||
if err := nc.send(data); err != nil {
|
||||
err = nc.send(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
version: version,
|
||||
created: now,
|
||||
}
|
||||
// Initialize last used timestamp
|
||||
newEntry.lastUsed.Store(now.Unix())
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
|
||||
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
|
||||
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
// Add connection to the list (lock-free)
|
||||
nodeConn.addConnection(newEntry)
|
||||
|
||||
// Use the worker pool for controlled concurrency instead of direct generation
|
||||
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
@@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
// Node will automatically receive updates through the normal flow
|
||||
// The initial full map already contains all current state
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection established in batcher because AddNode completed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return false
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
// Remove specific connection
|
||||
removed := nodeConn.removeConnectionByChannel(c)
|
||||
if !removed {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if node has any remaining active connections
|
||||
@@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
// No active connections - keep the node entry alive for rapid reconnections
|
||||
// The node will get a fresh full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
|
||||
b.addWork(c...)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
@@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() {
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel = nil // Prevent multiple calls
|
||||
}
|
||||
|
||||
// Only close workCh once
|
||||
select {
|
||||
case <-b.workCh:
|
||||
// Channel is already closed
|
||||
default:
|
||||
close(b.workCh)
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
// Create a cleanup ticker for removing truly disconnected nodes
|
||||
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-cleanupTicker.C:
|
||||
// Clean up nodes that have been offline for too long
|
||||
b.cleanupOfflineNodes()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() {
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
@@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
var err error
|
||||
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
@@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
@@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
all, self := change.SplitAllAndSelf(c)
|
||||
|
||||
for _, changeSet := range self {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, changeSet)
|
||||
b.pendingChanges.Store(changeSet.NodeID, changes)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
rel := change.RemoveUpdatesForSelf(nodeID, all)
|
||||
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
changes = append(changes, rel...)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
@@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
||||
func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
cleanupThreshold := 15 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
var nodesToCleanup []types.NodeID
|
||||
|
||||
// Find nodes that have been offline for too long
|
||||
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
||||
// Double-check the node doesn't have active connections
|
||||
if nodeConn, exists := b.nodes.Load(nodeID); exists {
|
||||
if !nodeConn.hasActiveConnections() {
|
||||
nodesToCleanup = append(nodesToCleanup, nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Msg("Cleaning up node that has been offline for too long")
|
||||
|
||||
b.nodes.Delete(nodeID)
|
||||
b.connected.Delete(nodeID)
|
||||
b.totalNodes.Add(-1)
|
||||
}
|
||||
|
||||
if len(nodesToCleanup) > 0 {
|
||||
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
|
||||
Msg("Completed cleanup of long-offline nodes")
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read that checks if a node has any active connections.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists {
|
||||
@@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
|
||||
}
|
||||
}
|
||||
|
||||
// connectionData holds the channel and connection parameters.
|
||||
type connectionData struct {
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
closed atomic.Bool // Track if this connection has been closed
|
||||
// connectionEntry represents a single connection to a node.
|
||||
type connectionEntry struct {
|
||||
id string // unique connection ID
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
created time.Time
|
||||
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
||||
}
|
||||
|
||||
// nodeConn described the node connection and its associated data.
|
||||
type nodeConn struct {
|
||||
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
||||
type multiChannelNodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
|
||||
// Atomic pointer to connection data - allows lock-free updates
|
||||
connData atomic.Pointer[connectionData]
|
||||
mutex sync.RWMutex
|
||||
connections []*connectionEntry
|
||||
|
||||
updateCount atomic.Int64
|
||||
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
// generateConnectionID produces a connection identifier: 8 bytes read from
// crypto/rand, rendered as a 16-character lowercase hexadecimal string.
func generateConnectionID() string {
	var raw [8]byte

	// The read result is deliberately discarded: the function has no error
	// return, and on failure the ID is built from whatever the buffer holds.
	_, _ = rand.Read(raw[:])

	return fmt.Sprintf("%x", raw[:])
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
||||
return &multiChannelNodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Dur("mutex_wait_time", mutexWaitDur).
|
||||
Msg("Successfully added connection after mutex wait")
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches current connection.
|
||||
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return false
|
||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for i, entry := range mc.connections {
|
||||
if entry.c == c {
|
||||
// Remove this connection
|
||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Compare channel pointers directly
|
||||
return data.c == c
|
||||
return false
|
||||
}
|
||||
|
||||
// compressAndVersion atomically reads connection settings.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
// hasActiveConnections checks if the node has any active connections.
|
||||
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections) > 0
|
||||
}
|
||||
|
||||
// getActiveConnectionCount returns the number of active connections.
|
||||
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections)
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
// During rapid reconnection, nodes may temporarily have no active connections
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
successCount := 0
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
for i, conn := range mc.connections {
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
lastErr = err
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove failed connections (in reverse order to maintain indices)
|
||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||
idx := failedConnections[i]
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Str("conn.id", mc.connections[idx].id).
|
||||
Msg("send: removing failed connection")
|
||||
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
log.Info().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failedConnections)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("send: completed broadcast")
|
||||
|
||||
// Success if at least one send succeeded
|
||||
if successCount > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
|
||||
}
|
||||
|
||||
// send sends data to a single connection entry with timeout-based stale connection detection.
|
||||
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
||||
// This is critical for detecting Docker containers that are forcefully terminated
|
||||
// but still have channels that appear open.
|
||||
select {
|
||||
case entry.c <- data:
|
||||
// Update last used timestamp on successful send
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
}
|
||||
}
|
||||
|
||||
// nodeID returns the node ID.
|
||||
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
||||
return mc.id
|
||||
}
|
||||
|
||||
// version returns the capability version from the first active connection.
|
||||
// All connections for a node should have the same version in practice.
|
||||
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
return mc.connections[0].version
|
||||
}
|
||||
|
||||
func (nc *nodeConn) nodeID() types.NodeID {
|
||||
return nc.id
|
||||
// change applies a change to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(mc, mc.mapper, c)
|
||||
}
|
||||
|
||||
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(nc, nc.mapper, c)
|
||||
// DebugNodeInfo is the per-node connection summary exposed for debugging.
type DebugNodeInfo struct {
	// Connected reports whether the node is considered connected.
	Connected bool `json:"connected"`
	// ActiveConnections is the number of connections registered for the node.
	ActiveConnections int `json:"active_connections"`
}
|
||||
|
||||
// send sends data to the node's channel.
|
||||
// The node will pick it up and send it to the HTTP handler.
|
||||
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
connData := nc.connData.Load()
|
||||
if connData == nil {
|
||||
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||
}
|
||||
// Debug returns a pre-baked map of node debug information for the debug interface.
|
||||
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
result := make(map[types.NodeID]DebugNodeInfo)
|
||||
|
||||
// Check if connection has been closed
|
||||
if connData.closed.Load() {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
// Get all nodes with their connection status using immediate connection logic
|
||||
// (no grace period) for debug purposes
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
nodeConn.mutex.RLock()
|
||||
activeConnCount := len(nodeConn.connections)
|
||||
nodeConn.mutex.RUnlock()
|
||||
|
||||
// Use immediate connection status: if active connections exist, node is connected
|
||||
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
|
||||
connected := false
|
||||
if activeConnCount > 0 {
|
||||
connected = true
|
||||
} else {
|
||||
// Check connected map for immediate status
|
||||
if val, ok := b.connected.Load(id); ok && val == nil {
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
|
||||
@@ -209,6 +209,7 @@ func setupBatcherWithTestData(
|
||||
|
||||
// Create test users and nodes in the database
|
||||
users := database.CreateUsersForTest(userCount, "testuser")
|
||||
|
||||
allNodes := make([]node, 0, userCount*nodesPerUser)
|
||||
for _, user := range users {
|
||||
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
|
||||
@@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b
|
||||
if len(resp.PeersChangedPatch) > 0 {
|
||||
require.Len(t, resp.PeersChangedPatch, 1)
|
||||
assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -412,6 +414,7 @@ func (n *node) start() {
|
||||
n.maxPeersCount = info.PeerCount
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsPatch {
|
||||
atomic.AddInt64(&n.patchCount, 1)
|
||||
// For patches, we track how many patch items
|
||||
@@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Test cases: different node counts to stress test the all-to-all connectivity
|
||||
@@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
// Join all nodes as fast as possible
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
@@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
if stats.MaxPeersSeen > maxPeersGlobal {
|
||||
maxPeersGlobal = stats.MaxPeersSeen
|
||||
}
|
||||
|
||||
if stats.MaxPeersSeen < minPeersSeen {
|
||||
minPeersSeen = stats.MaxPeersSeen
|
||||
}
|
||||
@@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show sample of node details
|
||||
if len(nodeDetails) > 0 {
|
||||
t.Logf(" Node sample:")
|
||||
|
||||
for _, detail := range nodeDetails[:min(5, len(nodeDetails))] {
|
||||
t.Logf(" %s", detail)
|
||||
}
|
||||
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf(" ... (%d more nodes)", len(nodeDetails)-5)
|
||||
}
|
||||
@@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show details of failed nodes for debugging
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf("Failed nodes details:")
|
||||
|
||||
for _, detail := range nodeDetails[5:] {
|
||||
if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) {
|
||||
t.Logf(" %s", detail)
|
||||
@@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
@@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// Collect updates with timeout
|
||||
updateCount := 0
|
||||
timeout := time.After(200 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
updateCount++
|
||||
|
||||
receivedUpdates = append(receivedUpdates, data)
|
||||
|
||||
// Validate update content
|
||||
@@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
|
||||
// Validate that all updates have valid content
|
||||
validUpdates := 0
|
||||
|
||||
for _, data := range receivedUpdates {
|
||||
if data != nil {
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
@@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var channelIssues int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Run rapid connect/disconnect cycles with real updates to test channel closing
|
||||
|
||||
for i := range 100 {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// First connection
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
// Rapid second connection - should replace ch1
|
||||
ch2 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
@@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
// If no data received, increment issues counter
|
||||
mutex.Lock()
|
||||
|
||||
channelIssues++
|
||||
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var panics int
|
||||
var channelErrors int
|
||||
var invalidData int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
panics int
|
||||
channelErrors int
|
||||
invalidData int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Test rapid connect/disconnect with work generation
|
||||
|
||||
for i := range 50 {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
panics++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Panic caught: %v", r)
|
||||
}
|
||||
@@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
channelErrors++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Channel consumer panic: %v", r)
|
||||
}
|
||||
@@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// Validate the data we received
|
||||
if valid, reason := validateUpdateContent(data); !valid {
|
||||
mutex.Lock()
|
||||
|
||||
invalidData++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Invalid data received: %s", reason)
|
||||
}
|
||||
@@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
if panics > 0 {
|
||||
t.Errorf("Worker channel safety failed with %d panics", panics)
|
||||
}
|
||||
|
||||
if channelErrors > 0 {
|
||||
t.Errorf("Channel handling failed with %d channel errors", channelErrors)
|
||||
}
|
||||
|
||||
if invalidData > 0 {
|
||||
t.Errorf("Data validation failed with %d invalid data packets", invalidData)
|
||||
}
|
||||
@@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Use remaining nodes for connection churn testing
|
||||
churningNodes := allNodes[len(allNodes)/2:]
|
||||
churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
var churningChannelsMutex sync.Mutex // Protect concurrent map access
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numCycles := 10 // Reduced for simpler test
|
||||
panicCount := 0
|
||||
|
||||
var panicMutex sync.Mutex
|
||||
|
||||
// Track deadlock with timeout
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
@@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning connect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
@@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning disconnect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
time.Sleep(time.Duration(i%5) * time.Millisecond)
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
@@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// DERP changes
|
||||
batcher.AddWork(change.DERPSet)
|
||||
}
|
||||
|
||||
if i%5 == 0 {
|
||||
// Full updates using real node data
|
||||
batcher.AddWork(change.FullSet)
|
||||
}
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
@@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Validate results
|
||||
panicMutex.Lock()
|
||||
|
||||
finalPanicCount := panicCount
|
||||
|
||||
panicMutex.Unlock()
|
||||
|
||||
allStats := tracker.getAllStats()
|
||||
@@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Full test matrix for scalability testing
|
||||
@@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description)
|
||||
t.Logf(
|
||||
" Cycles: %d, Buffer Size: %d, Chaos Type: %s",
|
||||
@@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Connect all nodes first so they can see each other as peers
|
||||
connectedNodes := make(map[types.NodeID]bool)
|
||||
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
t.Logf(
|
||||
@@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// For chaos testing, only disconnect/reconnect a subset of nodes
|
||||
// This ensures some nodes stay connected to continue receiving updates
|
||||
startIdx := cycle % len(testNodes)
|
||||
|
||||
endIdx := startIdx + len(testNodes)/4
|
||||
if endIdx > len(testNodes) {
|
||||
endIdx = len(testNodes)
|
||||
}
|
||||
|
||||
if startIdx >= endIdx {
|
||||
startIdx = 0
|
||||
endIdx = min(len(testNodes)/4, len(testNodes))
|
||||
}
|
||||
|
||||
chaosNodes := testNodes[startIdx:endIdx]
|
||||
if len(chaosNodes) == 0 {
|
||||
chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos
|
||||
@@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
connectedNodesMutex.RLock()
|
||||
|
||||
isConnected := connectedNodes[nodeID]
|
||||
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = false
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
}(
|
||||
@@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
|
||||
// Add work to create load
|
||||
@@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count
|
||||
for i := range updateCount {
|
||||
wg.Add(1)
|
||||
|
||||
go func(index int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
deadlockDetected = true
|
||||
// Collect diagnostic information
|
||||
allStats := tracker.getAllStats()
|
||||
|
||||
totalUpdates := 0
|
||||
for _, stats := range allStats {
|
||||
totalUpdates += stats.TotalUpdates
|
||||
}
|
||||
|
||||
interimPanics := atomic.LoadInt64(&panicCount)
|
||||
|
||||
t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT)
|
||||
t.Logf(
|
||||
" Progress at timeout: %d total updates, %d panics",
|
||||
@@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
stats := node.cleanup()
|
||||
totalUpdates += stats.TotalUpdates
|
||||
totalPatches += stats.PatchUpdates
|
||||
|
||||
totalFull += stats.FullUpdates
|
||||
if stats.MaxPeersSeen > maxPeersGlobal {
|
||||
maxPeersGlobal = stats.MaxPeersSeen
|
||||
@@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Legacy tracker comparison (optional)
|
||||
allStats := tracker.getAllStats()
|
||||
|
||||
legacyTotalUpdates := 0
|
||||
for _, stats := range allStats {
|
||||
legacyTotalUpdates += stats.TotalUpdates
|
||||
}
|
||||
|
||||
if legacyTotalUpdates != int(totalUpdates) {
|
||||
t.Logf(
|
||||
"Note: Legacy tracker mismatch - legacy: %d, new: %d",
|
||||
@@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Validation based on expectation
|
||||
testPassed := true
|
||||
|
||||
if tc.expectBreak {
|
||||
// For tests expected to break, we're mainly checking that we don't crash
|
||||
if finalPanicCount > 0 {
|
||||
@@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// For tests expected to pass, validate proper operation
|
||||
if finalPanicCount > 0 {
|
||||
t.Errorf("Scalability test failed with %d panics", finalPanicCount)
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
|
||||
if deadlockDetected {
|
||||
t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes))
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
|
||||
if totalUpdates == 0 {
|
||||
t.Error("No updates received - system may be completely stalled")
|
||||
|
||||
testPassed = false
|
||||
}
|
||||
}
|
||||
@@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
// Read all available updates for each node
|
||||
for i := range allNodes {
|
||||
nodeUpdates := 0
|
||||
|
||||
t.Logf("Reading updates for node %d:", i)
|
||||
|
||||
// Read up to 10 updates per node or until timeout/no more data
|
||||
@@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
if len(data.Peers) > 0 {
|
||||
t.Logf(" Full peer list with %d peers", len(data.Peers))
|
||||
|
||||
for j, peer := range data.Peers[:min(3, len(data.Peers))] {
|
||||
t.Logf(
|
||||
" Peer %d: NodeID=%d, Online=%v",
|
||||
@@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(data.PeersChangedPatch) > 0 {
|
||||
t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch))
|
||||
|
||||
for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] {
|
||||
t.Logf(
|
||||
" Patch %d: NodeID=%d, Online=%v",
|
||||
@@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Node %d received %d updates", i, nodeUpdates)
|
||||
}
|
||||
|
||||
@@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items.
|
||||
func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID
|
||||
// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected.
|
||||
// This specifically tests the multi-channel batcher implementation issue.
|
||||
func TestBatcherRapidReconnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
t.Logf("=== RAPID RECONNECTION TEST ===")
|
||||
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
for i, node := range allNodes {
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
for i, node := range allNodes {
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let reconnections settle
|
||||
|
||||
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
|
||||
t.Logf("Phase 4: Checking debug status...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i, node := range allNodes {
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
// Check if the debug info shows the node as connected
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
// Also check IsConnected method
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Logf("Node %d IsConnected() returns false", i)
|
||||
}
|
||||
}
|
||||
|
||||
if disconnectedCount > 0 {
|
||||
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
|
||||
// This is expected behavior for multi-channel batcher according to user
|
||||
// "it has never worked with the multi"
|
||||
} else {
|
||||
t.Logf("All nodes show as connected - working correctly")
|
||||
}
|
||||
} else {
|
||||
t.Logf("Batcher does not implement Debug() method")
|
||||
}
|
||||
|
||||
// Phase 5: Test if "disconnected" nodes can actually receive updates
|
||||
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
|
||||
|
||||
// Send a change that should reach all nodes
|
||||
batcher.AddWork(change.DERPChange())
|
||||
|
||||
receivedCount := 0
|
||||
timeout := time.After(500 * time.Millisecond)
|
||||
|
||||
for i := 0; i < len(allNodes); i++ {
|
||||
select {
|
||||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Logf("Node %d timed out waiting for update", i)
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes))
|
||||
|
||||
if receivedCount < len(allNodes) {
|
||||
t.Logf("Some nodes failed to receive updates - confirming the issue")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
nodes := testData.Nodes
|
||||
|
||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Wait for initial NodeCameOnline to be processed
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Drain any initial updates
|
||||
drainedCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-nodes[0].ch:
|
||||
drainedCount++
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
t.Logf("Drained %d initial updates", drainedCount)
|
||||
|
||||
// Now send a single FullSet update and trace it closely
|
||||
t.Logf("Sending change.FullSet work item...")
|
||||
batcher.AddWork(change.FullSet)
|
||||
|
||||
// Give short time for processing
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check if any update was received
|
||||
select {
|
||||
case data := <-nodes[0].ch:
|
||||
t.Logf("SUCCESS: Received update after FullSet!")
|
||||
|
||||
if data != nil {
|
||||
// Detailed analysis of the response - data is already a MapResponse
|
||||
t.Logf("Response details:")
|
||||
t.Logf(" Peers: %d", len(data.Peers))
|
||||
t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch))
|
||||
t.Logf(" PeersChanged: %d", len(data.PeersChanged))
|
||||
t.Logf(" PeersRemoved: %d", len(data.PeersRemoved))
|
||||
t.Logf(" DERPMap: %v", data.DERPMap != nil)
|
||||
t.Logf(" KeepAlive: %v", data.KeepAlive)
|
||||
t.Logf(" Node: %v", data.Node != nil)
|
||||
|
||||
if len(data.Peers) > 0 {
|
||||
t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers))
|
||||
} else if len(data.PeersChangedPatch) > 0 {
|
||||
t.Errorf("ERROR: Received patch update instead of full update!")
|
||||
} else if data.DERPMap != nil {
|
||||
t.Logf("Received DERP map update")
|
||||
} else if data.Node != nil {
|
||||
t.Logf("Received self node update")
|
||||
} else {
|
||||
t.Errorf("ERROR: Received unknown update type!")
|
||||
}
|
||||
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
@@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Response data is nil")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!")
|
||||
t.Errorf("This indicates FullSet work items are not being processed at all")
|
||||
}
|
||||
|
||||
// Send another update and verify remaining connections still work
|
||||
clearChannel(node1.ch)
|
||||
clearChannel(thirdChannel)
|
||||
|
||||
testChangeSet2 := change.ChangeSet{
|
||||
NodeID: node2.n.ID,
|
||||
Change: change.NodeNewOrUpdate,
|
||||
SelfUpdateOnly: false,
|
||||
}
|
||||
|
||||
batcher.AddWork(testChangeSet2)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify remaining connections still receive updates
|
||||
remaining1Received := false
|
||||
remaining3Received := false
|
||||
|
||||
select {
|
||||
case mapResp := <-node1.ch:
|
||||
remaining1Received = (mapResp != nil)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 1 did not receive update after removal")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-thirdChannel:
|
||||
remaining3Received = (mapResp != nil)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 3 did not receive update after removal")
|
||||
}
|
||||
|
||||
if remaining1Received && remaining3Received {
|
||||
t.Logf("SUCCESS: Remaining connections still receive updates after removal")
|
||||
} else {
|
||||
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t",
|
||||
remaining1Received, remaining3Received)
|
||||
}
|
||||
|
||||
// Verify second channel no longer receives updates (should be closed/removed)
|
||||
select {
|
||||
case <-secondChannel:
|
||||
t.Errorf("Removed connection still received update - this should not happen")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Logf("SUCCESS: Removed connection correctly no longer receives updates")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ type MapResponseBuilder struct {
|
||||
nodeID types.NodeID
|
||||
capVer tailcfg.CapabilityVersion
|
||||
errs []error
|
||||
|
||||
debugType debugType
|
||||
}
|
||||
|
||||
type debugType string
|
||||
|
||||
@@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(fullResponseDebug).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
@@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(derpResponseDebug).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
|
||||
@@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse(
|
||||
changed []*tailcfg.PeerChange,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(patchResponseDebug).
|
||||
WithPeerChangedPatch(changed).
|
||||
Build()
|
||||
}
|
||||
@@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse(
|
||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(changeResponseDebug).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peers).
|
||||
@@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse(
|
||||
removedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(removeResponseDebug).
|
||||
WithPeersRemoved(removedNodeID).
|
||||
Build()
|
||||
}
|
||||
@@ -214,7 +218,7 @@ func writeDebugMapResponse(
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -224,7 +228,7 @@ func writeDebugMapResponse(
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s.json", now),
|
||||
fmt.Sprintf("%s-%s.json", now, t),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
@@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nodes, err := os.ReadDir(debugDumpMapResponsePath)
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
}
|
||||
|
||||
func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
nodes, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
|
||||
nodeID := types.NodeID(nodeIDu)
|
||||
|
||||
files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
|
||||
files, err := os.ReadDir(path.Join(dir, node.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
|
||||
continue
|
||||
@@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
|
||||
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading file %s", file.Name())
|
||||
continue
|
||||
|
||||
@@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
Tags: []string{},
|
||||
|
||||
LastSeen: &lastSeen,
|
||||
MachineAuthorized: true,
|
||||
|
||||
CapMap: tailcfg.NodeCapMap{
|
||||
|
||||
@@ -175,8 +175,8 @@ func rejectUnsupported(
|
||||
Int("client_cap_ver", int(version)).
|
||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||
Str("client_version", capver.TailscaleVersion(version)).
|
||||
Str("node_key", nkey.ShortString()).
|
||||
Str("machine_key", mkey.ShortString()).
|
||||
Str("node.key", nkey.ShortString()).
|
||||
Str("machine.key", mkey.ShortString()).
|
||||
Msg("unsupported client connected")
|
||||
http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest)
|
||||
|
||||
@@ -282,7 +282,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
a.registrationCache.Set(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||
}
|
||||
@@ -311,7 +311,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(werr).
|
||||
Msg("Failed to write response")
|
||||
Msg("Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
@@ -349,7 +349,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
util.LogErr(err, "Failed to write HTTP response")
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
@@ -34,7 +34,7 @@ func (pol *Policy) compileFilterRules(
|
||||
|
||||
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
log.Trace().Caller().Err(err).Msgf("resolving source ips")
|
||||
}
|
||||
|
||||
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
|
||||
@@ -52,11 +52,11 @@ func (pol *Policy) compileFilterRules(
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
if ips == nil {
|
||||
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
|
||||
log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
@@ -115,7 +115,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
for _, src := range rule.Destinations {
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
continue // Skip this rule if we can't resolve sources
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -79,6 +80,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
if filterChanged {
|
||||
log.Debug().
|
||||
Str("filter.hash.old", pm.filterHash.String()[:8]).
|
||||
Str("filter.hash.new", filterHash.String()[:8]).
|
||||
Int("filter.rules", len(pm.filter)).
|
||||
Int("filter.rules.new", len(filter)).
|
||||
Msg("Policy filter hash changed")
|
||||
}
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
@@ -95,6 +104,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
if tagOwnerChanged {
|
||||
log.Debug().
|
||||
Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]).
|
||||
Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]).
|
||||
Int("tagOwners.old", len(pm.tagOwnerMap)).
|
||||
Int("tagOwners.new", len(tagMap)).
|
||||
Msg("Tag owner hash changed")
|
||||
}
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
@@ -105,19 +122,42 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
if autoApproveChanged {
|
||||
log.Debug().
|
||||
Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]).
|
||||
Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]).
|
||||
Int("autoApprovers.old", len(pm.autoApproveMap)).
|
||||
Int("autoApprovers.new", len(autoMap)).
|
||||
Msg("Auto-approvers hash changed")
|
||||
}
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
exitSetHash := deephash.Hash(&autoMap)
|
||||
exitSetHash := deephash.Hash(&exitSet)
|
||||
exitSetChanged := exitSetHash != pm.exitSetHash
|
||||
if exitSetChanged {
|
||||
log.Debug().
|
||||
Str("exitSet.hash.old", pm.exitSetHash.String()[:8]).
|
||||
Str("exitSet.hash.new", exitSetHash.String()[:8]).
|
||||
Msg("Exit node set hash changed")
|
||||
}
|
||||
pm.exitSet = exitSet
|
||||
pm.exitSetHash = exitSetHash
|
||||
|
||||
// If neither of the calculated values changed, no need to update nodes
|
||||
if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged {
|
||||
log.Trace().
|
||||
Msg("Policy evaluation detected no changes - all hashes match")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Bool("filter.changed", filterChanged).
|
||||
Bool("tagOwners.changed", tagOwnerChanged).
|
||||
Bool("autoApprovers.changed", autoApproveChanged).
|
||||
Bool("exitNodes.changed", exitSetChanged).
|
||||
Msg("Policy changes require node updates")
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -151,6 +191,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
// Log policy metadata for debugging
|
||||
log.Debug().
|
||||
Int("policy.bytes", len(polB)).
|
||||
Int("acls.count", len(pol.ACLs)).
|
||||
Int("groups.count", len(pol.Groups)).
|
||||
Int("hosts.count", len(pol.Hosts)).
|
||||
Int("tagOwners.count", len(pol.TagOwners)).
|
||||
Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)).
|
||||
Msg("Policy parsed successfully")
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
|
||||
@@ -216,6 +216,21 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
|
||||
|
||||
// TODO(kradalby): Redo the comments here
|
||||
// Add node to batcher so it can receive updates,
|
||||
// adding this before connecting it to the state ensures that
|
||||
// it does not miss any updates that might be sent in the split
|
||||
// time between the node connecting and the batcher being ready.
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
|
||||
return
|
||||
}
|
||||
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
|
||||
|
||||
m.h.Change(mapReqChange)
|
||||
m.h.Change(connectChanges...)
|
||||
|
||||
// Loop through updates and continuously send them to the
|
||||
// client.
|
||||
for {
|
||||
@@ -227,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
m.tracef("poll context done")
|
||||
m.tracef("poll context done chan:%p", m.ch)
|
||||
mapResponseEnded.WithLabelValues("done").Inc()
|
||||
return
|
||||
|
||||
@@ -295,7 +310,15 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node.name", m.node.Hostname).
|
||||
Uint64("node.id", m.node.ID.Uint64()).
|
||||
Str("chan", fmt.Sprintf("%p", m.ch)).
|
||||
TimeDiff("timeSpent", time.Now(), startWrite).
|
||||
Str("machine.key", m.node.MachineKey.String()).
|
||||
Bool("keepalive", msg.KeepAlive).
|
||||
Msgf("finished writing mapresp to node chan(%p)", m.ch)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -305,14 +328,14 @@ var keepAlive = tailcfg.MapResponse{
|
||||
}
|
||||
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
|
||||
trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
|
||||
|
||||
if peerChange.Key != nil {
|
||||
trace = trace.Str("node_key", peerChange.Key.ShortString())
|
||||
trace = trace.Str("node.key", peerChange.Key.ShortString())
|
||||
}
|
||||
|
||||
if peerChange.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
|
||||
trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString())
|
||||
}
|
||||
|
||||
if peerChange.Online != nil {
|
||||
@@ -349,7 +372,7 @@ func logPollFunc(
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Str("node", node.Hostname).
|
||||
Str("node.name", node.Hostname).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
@@ -358,7 +381,7 @@ func logPollFunc(
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Str("node", node.Hostname).
|
||||
Str("node.name", node.Hostname).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
@@ -367,7 +390,7 @@ func logPollFunc(
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Str("node", node.Hostname).
|
||||
Str("node.name", node.Hostname).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(err error, msg string, a ...any) {
|
||||
@@ -376,7 +399,7 @@ func logPollFunc(
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Str("node", node.Hostname).
|
||||
Str("node.name", node.Hostname).
|
||||
Err(err).
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
|
||||
@@ -1430,7 +1430,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
|
||||
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
|
||||
log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")
|
||||
|
||||
changed, err := s.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
|
||||
@@ -97,6 +97,35 @@ func (c ChangeSet) IsFull() bool {
|
||||
return c.Change == Full || c.Change == Policy
|
||||
}
|
||||
|
||||
func HasFull(cs []ChangeSet) bool {
|
||||
for _, c := range cs {
|
||||
if c.IsFull() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
|
||||
for _, c := range cs {
|
||||
if c.SelfUpdateOnly {
|
||||
self = append(self, c)
|
||||
} else {
|
||||
all = append(all, c)
|
||||
}
|
||||
}
|
||||
return all, self
|
||||
}
|
||||
|
||||
func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
|
||||
for _, c := range cs {
|
||||
if c.NodeID != id || c.Change.AlsoSelf() {
|
||||
ret = append(ret, c)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c ChangeSet) AlsoSelf() bool {
|
||||
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
|
||||
// so we consider it as a change that should be sent to all nodes.
|
||||
|
||||
@@ -489,6 +489,7 @@ func derpConfig() DERPConfig {
|
||||
urlAddr, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("url", urlStr).
|
||||
Err(err).
|
||||
Msg("Failed to parse url, ignoring...")
|
||||
@@ -561,6 +562,7 @@ func logConfig() LogConfig {
|
||||
logFormat = TextLogFormat
|
||||
default:
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "GetLogConfig").
|
||||
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
|
||||
}
|
||||
|
||||
@@ -54,6 +54,20 @@ func (id NodeID) String() string {
|
||||
return strconv.FormatUint(id.Uint64(), util.Base10)
|
||||
}
|
||||
|
||||
func ParseNodeID(s string) (NodeID, error) {
|
||||
id, err := strconv.ParseUint(s, util.Base10, 64)
|
||||
return NodeID(id), err
|
||||
}
|
||||
|
||||
func MustParseNodeID(s string) NodeID {
|
||||
id, err := ParseNodeID(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
// Node is a Headscale client.
|
||||
type Node struct {
|
||||
ID NodeID `gorm:"primary_key"`
|
||||
|
||||
@@ -61,6 +61,7 @@ func (pak *PreAuthKey) Validate() error {
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("key", pak.Key).
|
||||
Bool("hasExpiration", pak.Expiration != nil).
|
||||
Time("expiration", func() time.Time {
|
||||
|
||||
@@ -321,7 +321,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
|
||||
if err == nil {
|
||||
u.Name = claims.Username
|
||||
} else {
|
||||
log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username)
|
||||
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
|
||||
}
|
||||
|
||||
if claims.EmailVerified {
|
||||
|
||||
Reference in New Issue
Block a user