mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-14 21:09:52 +02:00
auth: generalise auth flow and introduce AuthVerdict
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
This commit is contained in:
@@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{
|
|||||||
name, _ := cmd.Flags().GetString("name")
|
name, _ := cmd.Flags().GetString("name")
|
||||||
registrationID, _ := cmd.Flags().GetString("key")
|
registrationID, _ := cmd.Flags().GetString("key")
|
||||||
|
|
||||||
_, err := types.RegistrationIDFromString(registrationID)
|
_, err := types.AuthIDFromString(registrationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing machine key: %w", err)
|
return fmt.Errorf("parsing machine key: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
|
|||||||
r.Get("/health", h.HealthHandler)
|
r.Get("/health", h.HealthHandler)
|
||||||
r.Get("/version", h.VersionHandler)
|
r.Get("/version", h.VersionHandler)
|
||||||
r.Get("/key", h.KeyHandler)
|
r.Get("/key", h.KeyHandler)
|
||||||
r.Get("/register/{registration_id}", h.authProvider.RegisterHandler)
|
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
|
||||||
|
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
|
||||||
|
|
||||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||||
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
|
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ import (
|
|||||||
|
|
||||||
type AuthProvider interface {
|
type AuthProvider interface {
|
||||||
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
||||||
AuthURL(regID types.RegistrationID) string
|
AuthHandler(w http.ResponseWriter, r *http.Request)
|
||||||
|
RegisterURL(authID types.AuthID) string
|
||||||
|
AuthURL(authID types.AuthID) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleRegister(
|
func (h *Headscale) handleRegister(
|
||||||
@@ -263,22 +265,24 @@ func (h *Headscale) waitForFollowup(
|
|||||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
|
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
|
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
|
if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||||
case node := <-reg.Registered:
|
case verdict := <-reg.WaitForAuth():
|
||||||
if node == nil {
|
if verdict.Accept() {
|
||||||
// registration is expired in the cache, instruct the client to try a new registration
|
if !verdict.Node.Valid() {
|
||||||
return h.reqToNewRegisterResponse(req, machineKey)
|
// registration is expired in the cache, instruct the client to try a new registration
|
||||||
}
|
return h.reqToNewRegisterResponse(req, machineKey)
|
||||||
|
}
|
||||||
|
|
||||||
return nodeToRegisterResponse(node.View()), nil
|
return nodeToRegisterResponse(verdict.Node), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,14 +297,14 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
req tailcfg.RegisterRequest,
|
req tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) (*tailcfg.RegisterResponse, error) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
newRegID, err := types.NewRegistrationID()
|
newAuthID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have a valid hostname
|
// Ensure we have a valid hostname
|
||||||
hostname := util.EnsureHostname(
|
hostname := util.EnsureHostname(
|
||||||
req.Hostinfo,
|
req.Hostinfo.View(),
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
req.NodeKey.String(),
|
req.NodeKey.String(),
|
||||||
)
|
)
|
||||||
@@ -309,25 +313,25 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
||||||
hostinfo.Hostname = hostname
|
hostinfo.Hostname = hostname
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterNode(
|
nodeToRegister := types.Node{
|
||||||
types.Node{
|
Hostname: hostname,
|
||||||
Hostname: hostname,
|
MachineKey: machineKey,
|
||||||
MachineKey: machineKey,
|
NodeKey: req.NodeKey,
|
||||||
NodeKey: req.NodeKey,
|
Hostinfo: hostinfo,
|
||||||
Hostinfo: hostinfo,
|
LastSeen: new(time.Now()),
|
||||||
LastSeen: new(time.Now()),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if !req.Expiry.IsZero() {
|
|
||||||
nodeToRegister.Node.Expiry = &req.Expiry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("new followup node registration using key: %s", newRegID)
|
if !req.Expiry.IsZero() {
|
||||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
nodeToRegister.Expiry = &req.Expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||||
|
|
||||||
|
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
|
||||||
|
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
AuthURL: h.authProvider.AuthURL(newRegID),
|
AuthURL: h.authProvider.RegisterURL(newAuthID),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,13 +382,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||||||
// Send both changes. Empty changes are ignored by Change().
|
// Send both changes. Empty changes are ignored by Change().
|
||||||
h.Change(changed, routesChange)
|
h.Change(changed, routesChange)
|
||||||
|
|
||||||
// TODO(kradalby): I think this is covered above, but we need to validate that.
|
|
||||||
// // If policy changed due to node registration, send a separate policy change
|
|
||||||
// if policyChanged {
|
|
||||||
// policyChange := change.PolicyChange()
|
|
||||||
// h.Change(policyChange)
|
|
||||||
// }
|
|
||||||
|
|
||||||
resp := &tailcfg.RegisterResponse{
|
resp := &tailcfg.RegisterResponse{
|
||||||
MachineAuthorized: true,
|
MachineAuthorized: true,
|
||||||
NodeKeyExpired: node.IsExpired(),
|
NodeKeyExpired: node.IsExpired(),
|
||||||
@@ -406,14 +403,14 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
req tailcfg.RegisterRequest,
|
req tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) (*tailcfg.RegisterResponse, error) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
registrationId, err := types.NewRegistrationID()
|
authID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have a valid hostname
|
// Ensure we have a valid hostname
|
||||||
hostname := util.EnsureHostname(
|
hostname := util.EnsureHostname(
|
||||||
req.Hostinfo,
|
req.Hostinfo.View(),
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
req.NodeKey.String(),
|
req.NodeKey.String(),
|
||||||
)
|
)
|
||||||
@@ -436,28 +433,28 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
|
|
||||||
hostinfo.Hostname = hostname
|
hostinfo.Hostname = hostname
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterNode(
|
nodeToRegister := types.Node{
|
||||||
types.Node{
|
Hostname: hostname,
|
||||||
Hostname: hostname,
|
MachineKey: machineKey,
|
||||||
MachineKey: machineKey,
|
NodeKey: req.NodeKey,
|
||||||
NodeKey: req.NodeKey,
|
Hostinfo: hostinfo,
|
||||||
Hostinfo: hostinfo,
|
LastSeen: new(time.Now()),
|
||||||
LastSeen: new(time.Now()),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if !req.Expiry.IsZero() {
|
|
||||||
nodeToRegister.Node.Expiry = &req.Expiry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h.state.SetRegistrationCacheEntry(
|
if !req.Expiry.IsZero() {
|
||||||
registrationId,
|
nodeToRegister.Expiry = &req.Expiry
|
||||||
nodeToRegister,
|
}
|
||||||
|
|
||||||
|
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||||
|
|
||||||
|
h.state.SetAuthCacheEntry(
|
||||||
|
authID,
|
||||||
|
authRegReq,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.Info().Msgf("starting node registration using key: %s", registrationId)
|
log.Info().Msgf("starting node registration using key: %s", authID)
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
AuthURL: h.authProvider.AuthURL(registrationId),
|
AuthURL: h.authProvider.RegisterURL(authID),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
|
|
||||||
// Step 1: Create user-owned node WITH expiry set
|
// Step 1: Create user-owned node WITH expiry set
|
||||||
clientExpiry := time.Now().Add(24 * time.Hour)
|
clientExpiry := time.Now().Add(24 * time.Hour)
|
||||||
registrationID1 := types.MustRegistrationID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterNode(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "personal-to-tagged",
|
Hostname: "personal-to-tagged",
|
||||||
@@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Expiry: &clientExpiry,
|
Expiry: &clientExpiry,
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||||
|
|
||||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
registrationID1, types.UserID(user.ID), nil, "webauth",
|
registrationID1, types.UserID(user.ID), nil, "webauth",
|
||||||
@@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
|
|
||||||
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
registrationID2 := types.MustRegistrationID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterNode(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "personal-to-tagged",
|
Hostname: "personal-to-tagged",
|
||||||
@@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Expiry: &clientExpiry, // Client still sends expiry
|
Expiry: &clientExpiry, // Client still sends expiry
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||||
|
|
||||||
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
registrationID2, types.UserID(user.ID), nil, "webauth",
|
registrationID2, types.UserID(user.ID), nil, "webauth",
|
||||||
@@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
nodeKey1 := key.NewNode()
|
nodeKey1 := key.NewNode()
|
||||||
|
|
||||||
// Step 1: Create tagged node (expiry should be nil)
|
// Step 1: Create tagged node (expiry should be nil)
|
||||||
registrationID1 := types.MustRegistrationID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterNode(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "tagged-to-personal",
|
Hostname: "tagged-to-personal",
|
||||||
@@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
RequestTags: []string{"tag:server"}, // Tagged node
|
RequestTags: []string{"tag:server"}, // Tagged node
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||||
|
|
||||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
registrationID1, types.UserID(user.ID), nil, "webauth",
|
registrationID1, types.UserID(user.ID), nil, "webauth",
|
||||||
@@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
|
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
|
||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
clientExpiry := time.Now().Add(48 * time.Hour)
|
clientExpiry := time.Now().Add(48 * time.Hour)
|
||||||
registrationID2 := types.MustRegistrationID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterNode(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "tagged-to-personal",
|
Hostname: "tagged-to-personal",
|
||||||
@@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Expiry: &clientExpiry, // Client requests expiry
|
Expiry: &clientExpiry, // Client requests expiry
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||||
|
|
||||||
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
registrationID2, types.UserID(user.ID), nil, "webauth",
|
registrationID2, types.UserID(user.ID), nil, "webauth",
|
||||||
|
|||||||
@@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "followup_registration_success",
|
name: "followup_registration_success",
|
||||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||||
regID, err := types.NewRegistrationID()
|
regID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
registered := make(chan *types.Node, 1)
|
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||||
nodeToRegister := types.RegisterNode{
|
Hostname: "followup-success-node",
|
||||||
Node: types.Node{
|
})
|
||||||
Hostname: "followup-success-node",
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
},
|
|
||||||
Registered: registered,
|
|
||||||
}
|
|
||||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
|
||||||
|
|
||||||
// Simulate successful registration - send to buffered channel
|
// Simulate successful registration
|
||||||
// The channel is buffered (size 1), so this can complete immediately
|
// handleRegister will receive the value when it starts waiting
|
||||||
// and handleRegister will receive the value when it starts waiting
|
|
||||||
go func() {
|
go func() {
|
||||||
user := app.state.CreateUserForTest("followup-user")
|
user := app.state.CreateUserForTest("followup-user")
|
||||||
|
|
||||||
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
||||||
registered <- node
|
nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||||
@@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "followup_registration_timeout",
|
name: "followup_registration_timeout",
|
||||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||||
regID, err := types.NewRegistrationID()
|
regID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
registered := make(chan *types.Node, 1)
|
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||||
nodeToRegister := types.RegisterNode{
|
Hostname: "followup-timeout-node",
|
||||||
Node: types.Node{
|
})
|
||||||
Hostname: "followup-timeout-node",
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
},
|
// Don't call FinishRegistration - will timeout
|
||||||
Registered: registered,
|
|
||||||
}
|
|
||||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
|
||||||
// Don't send anything on channel - will timeout
|
|
||||||
|
|
||||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||||
},
|
},
|
||||||
@@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "followup_registration_node_nil_response",
|
name: "followup_registration_node_nil_response",
|
||||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||||
regID, err := types.NewRegistrationID()
|
regID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
registered := make(chan *types.Node, 1)
|
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||||
nodeToRegister := types.RegisterNode{
|
Hostname: "nil-response-node",
|
||||||
Node: types.Node{
|
})
|
||||||
Hostname: "nil-response-node",
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
},
|
|
||||||
Registered: registered,
|
|
||||||
}
|
|
||||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
|
||||||
|
|
||||||
// Simulate registration that returns nil (cache expired during auth)
|
// Simulate registration that returns empty NodeView (cache expired during auth)
|
||||||
// The channel is buffered (size 1), so this can complete immediately
|
|
||||||
go func() {
|
go func() {
|
||||||
registered <- nil // Nil indicates cache expiry
|
nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||||
@@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||||
// Generate a registration ID that doesn't exist in cache
|
// Generate a registration ID that doesn't exist in cache
|
||||||
// This simulates an expired/missing cache entry
|
// This simulates an expired/missing cache entry
|
||||||
regID, err := types.NewRegistrationID()
|
regID, err := types.NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
|
|
||||||
// Extract and validate the new registration ID exists in cache
|
// Extract and validate the new registration ID exists in cache
|
||||||
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
|
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
|
||||||
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
|
newRegID, err := types.AuthIDFromString(newRegIDStr)
|
||||||
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
|
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
|
||||||
|
|
||||||
// Verify new registration entry exists in cache
|
// Verify new registration entry exists in cache
|
||||||
_, found := app.state.GetRegistrationCacheEntry(newRegID)
|
_, found := app.state.GetAuthCacheEntry(newRegID)
|
||||||
assert.True(t, found, "new registration should exist in cache")
|
assert.True(t, found, "new registration should exist in cache")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify cache entry exists
|
// Verify cache entry exists
|
||||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||||
assert.True(t, found, "registration cache entry should exist initially")
|
assert.True(t, found, "registration cache entry should exist initially")
|
||||||
assert.NotNil(t, cacheEntry)
|
assert.NotNil(t, cacheEntry)
|
||||||
|
|
||||||
@@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
|
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
|
||||||
|
|
||||||
// Cache entry should still exist after auth error (for retry scenarios)
|
// Cache entry should still exist after auth error (for retry scenarios)
|
||||||
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
|
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
|
||||||
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
|
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
||||||
|
|
||||||
// Both cache entries should exist simultaneously
|
// Both cache entries should exist simultaneously
|
||||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||||
|
|
||||||
assert.True(t, found1, "first registration cache entry should exist")
|
assert.True(t, found1, "first registration cache entry should exist")
|
||||||
assert.True(t, found2, "second registration cache entry should exist")
|
assert.True(t, found2, "second registration cache entry should exist")
|
||||||
@@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify both exist
|
// Verify both exist
|
||||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||||
|
|
||||||
assert.True(t, found1, "first cache entry should exist")
|
assert.True(t, found1, "first cache entry should exist")
|
||||||
assert.True(t, found2, "second cache entry should exist")
|
assert.True(t, found2, "second cache entry should exist")
|
||||||
@@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First registration should still be in cache (not completed)
|
// First registration should still be in cache (not completed)
|
||||||
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
|
_, stillFound := app.state.GetAuthCacheEntry(regID1)
|
||||||
assert.True(t, stillFound, "first registration should still be pending")
|
assert.True(t, stillFound, "first registration should still be pending")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||||||
var (
|
var (
|
||||||
initialResp *tailcfg.RegisterResponse
|
initialResp *tailcfg.RegisterResponse
|
||||||
authURL string
|
authURL string
|
||||||
registrationID types.RegistrationID
|
registrationID types.AuthID
|
||||||
finalResp *tailcfg.RegisterResponse
|
finalResp *tailcfg.RegisterResponse
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
@@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||||||
|
|
||||||
if step.expectCacheEntry {
|
if step.expectCacheEntry {
|
||||||
// Verify registration cache entry was created
|
// Verify registration cache entry was created
|
||||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||||
require.True(t, found, "registration cache entry should exist")
|
require.True(t, found, "registration cache entry should exist")
|
||||||
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
||||||
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
|
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
|
||||||
}
|
}
|
||||||
|
|
||||||
case stepTypeAuthCompletion:
|
case stepTypeAuthCompletion:
|
||||||
@@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||||||
// Check cache cleanup expectation for this step
|
// Check cache cleanup expectation for this step
|
||||||
if step.expectCacheEntry == false && registrationID != "" {
|
if step.expectCacheEntry == false && registrationID != "" {
|
||||||
// Verify cache entry was cleaned up
|
// Verify cache entry was cleaned up
|
||||||
_, found := app.state.GetRegistrationCacheEntry(registrationID)
|
_, found := app.state.GetAuthCacheEntry(registrationID)
|
||||||
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
|
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
||||||
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
|
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
|
||||||
// AuthURL format: "http://localhost/register/abc123"
|
// AuthURL format: "http://localhost/register/abc123"
|
||||||
const registerPrefix = "/register/"
|
const registerPrefix = "/register/"
|
||||||
|
|
||||||
@@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
|
|||||||
|
|
||||||
idStr := authURL[idx+len(registerPrefix):]
|
idStr := authURL[idx+len(registerPrefix):]
|
||||||
|
|
||||||
return types.RegistrationIDFromString(idStr)
|
return types.AuthIDFromString(idStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
||||||
@@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
|||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
|
||||||
// Simulate a registration cache entry (as would be created during web auth)
|
// Simulate a registration cache entry (as would be created during web auth)
|
||||||
registrationID := types.MustRegistrationID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterNode(types.Node{
|
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey.Public(),
|
NodeKey: nodeKey.Public(),
|
||||||
Hostname: "webauth-tags-node",
|
Hostname: "webauth-tags-node",
|
||||||
@@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
|||||||
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
|
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||||
|
|
||||||
// Complete the web auth - should fail because tag is unauthorized
|
// Complete the web auth - should fail because tag is unauthorized
|
||||||
_, _, err := app.state.HandleNodeFromAuthPath(
|
_, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
@@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
nodeKey1 := key.NewNode()
|
nodeKey1 := key.NewNode()
|
||||||
|
|
||||||
// Step 1: Initial registration with tags
|
// Step 1: Initial registration with tags
|
||||||
registrationID1 := types.MustRegistrationID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterNode(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "reauth-untag-node",
|
Hostname: "reauth-untag-node",
|
||||||
@@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
RequestTags: []string{"tag:valid-owned", "tag:second"},
|
RequestTags: []string{"tag:valid-owned", "tag:second"},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||||
|
|
||||||
// Complete initial registration with tags
|
// Complete initial registration with tags
|
||||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
@@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
|
|
||||||
// Step 2: Reauth with EMPTY tags to untag
|
// Step 2: Reauth with EMPTY tags to untag
|
||||||
nodeKey2 := key.NewNode() // New node key for reauth
|
nodeKey2 := key.NewNode() // New node key for reauth
|
||||||
registrationID2 := types.MustRegistrationID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterNode(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(), // Same machine key
|
MachineKey: machineKey.Public(), // Same machine key
|
||||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||||
Hostname: "reauth-untag-node",
|
Hostname: "reauth-untag-node",
|
||||||
@@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
RequestTags: []string{}, // EMPTY - should untag
|
RequestTags: []string{}, // EMPTY - should untag
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||||
|
|
||||||
// Complete reauth with empty tags
|
// Complete reauth with empty tags
|
||||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
@@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
|||||||
|
|
||||||
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
||||||
nodeKey2 := key.NewNode() // New node key for reauth
|
nodeKey2 := key.NewNode() // New node key for reauth
|
||||||
registrationID := types.MustRegistrationID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterNode(types.Node{
|
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(), // Same machine key
|
MachineKey: machineKey.Public(), // Same machine key
|
||||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||||
Hostname: "authkey-tagged-node",
|
Hostname: "authkey-tagged-node",
|
||||||
@@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
|||||||
RequestTags: []string{}, // EMPTY - should untag
|
RequestTags: []string{}, // EMPTY - should untag
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||||
|
|
||||||
// Complete reauth with empty tags
|
// Complete reauth with empty tags
|
||||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||||
@@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
|||||||
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
|
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
|
||||||
// This is what happens when running: headscale nodes register --user alice --key ...
|
// This is what happens when running: headscale nodes register --user alice --key ...
|
||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
registrationID := types.MustRegistrationID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterNode(types.Node{
|
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||||
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "tagged-orphan-node",
|
Hostname: "tagged-orphan-node",
|
||||||
@@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
|||||||
RequestTags: []string{}, // Empty - transition to user-owned
|
RequestTags: []string{}, // Empty - transition to user-owned
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||||
|
|
||||||
// This should NOT panic - before the fix, this would panic with:
|
// This should NOT panic - before the fix, this would panic with:
|
||||||
// panic: runtime error: invalid memory address or nil pointer dereference
|
// panic: runtime error: invalid memory address or nil pointer dereference
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ const (
|
|||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||||
@@ -56,7 +56,7 @@ type HSDatabase struct {
|
|||||||
//nolint:gocyclo // complex database initialization with many migrations
|
//nolint:gocyclo // complex database initialization with many migrations
|
||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
|
||||||
) (*HSDatabase, error) {
|
) (*HSDatabase, error) {
|
||||||
dbConn, err := openDB(cfg.Database)
|
dbConn, err := openDB(cfg.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||||||
Str(zf.RegistrationKey, registrationKey).
|
Str(zf.RegistrationKey, registrationKey).
|
||||||
Msg("registering node")
|
Msg("registering node")
|
||||||
|
|
||||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -808,33 +808,32 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||||||
Hostname: request.GetName(),
|
Hostname: request.GetName(),
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newNode := types.NewRegisterNode(
|
newNode := types.Node{
|
||||||
types.Node{
|
NodeKey: key.NewNode().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
MachineKey: key.NewMachine().Public(),
|
Hostname: request.GetName(),
|
||||||
Hostname: request.GetName(),
|
User: user,
|
||||||
User: user,
|
|
||||||
|
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
LastSeen: &time.Time{},
|
LastSeen: &time.Time{},
|
||||||
|
|
||||||
Hostinfo: &hostinfo,
|
Hostinfo: &hostinfo,
|
||||||
},
|
}
|
||||||
)
|
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Str("registration_id", registrationId.String()).
|
Str("registration_id", registrationId.String()).
|
||||||
Msg("adding debug machine via CLI, appending to registration cache")
|
Msg("adding debug machine via CLI, appending to registration cache")
|
||||||
|
|
||||||
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
|
authRegReq := types.NewRegisterAuthRequest(newNode)
|
||||||
|
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
|
||||||
|
|
||||||
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
|
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) Health(
|
func (api headscaleV1APIServer) Health(
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/assets"
|
"github.com/juanfont/headscale/hscontrol/assets"
|
||||||
"github.com/juanfont/headscale/hscontrol/templates"
|
"github.com/juanfont/headscale/hscontrol/templates"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
@@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
|
func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"%s/register/%s",
|
"%s/register/%s",
|
||||||
strings.TrimSuffix(a.serverURL, "/"),
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
registrationId.String())
|
authID.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%s/auth/%s",
|
||||||
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
|
authID.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderWeb) AuthHandler(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
|
||||||
|
registrationId, err := urlParam[types.AuthID](req, "auth_id")
|
||||||
|
if err != nil {
|
||||||
|
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||||
|
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||||
|
// the template and log an error.
|
||||||
|
err = registrationId.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return registrationId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterHandler shows a simple message in the browser to point to the CLI
|
// RegisterHandler shows a simple message in the browser to point to the CLI
|
||||||
@@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler(
|
|||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
vars := mux.Vars(req)
|
registrationId, err := authIDFromRequest(req)
|
||||||
registrationIdStr := vars["registration_id"]
|
|
||||||
|
|
||||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
|
||||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
|
||||||
// the template and log an error.
|
|
||||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// emptyCache creates an empty registration cache for testing.
|
// emptyCache creates an empty registration cache for testing.
|
||||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test configuration constants.
|
// Test configuration constants.
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ import (
|
|||||||
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
|
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
|
||||||
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
|
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
|
||||||
|
|
||||||
|
// ErrMissingURLParameter is returned when a required URL parameter is not provided.
|
||||||
|
var ErrMissingURLParameter = errors.New("missing URL parameter")
|
||||||
|
|
||||||
|
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
|
||||||
|
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
||||||
ts2021UpgradePath = "/ts2021"
|
ts2021UpgradePath = "/ts2021"
|
||||||
@@ -374,3 +380,28 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
|
|||||||
|
|
||||||
return nv, nil
|
return nv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// urlParam extracts a typed URL parameter from a chi router request.
|
||||||
|
func urlParam[T any](req *http.Request, key string) (T, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
param := chi.URLParam(req, key)
|
||||||
|
if param == "" {
|
||||||
|
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
var value T
|
||||||
|
switch any(value).(type) {
|
||||||
|
case string:
|
||||||
|
v, ok := any(param).(T)
|
||||||
|
if !ok {
|
||||||
|
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
value = v
|
||||||
|
default:
|
||||||
|
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/templates"
|
"github.com/juanfont/headscale/hscontrol/templates"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
@@ -26,8 +25,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
randomByteSize = 16
|
randomByteSize = 16
|
||||||
defaultOAuthOptionsCount = 3
|
defaultOAuthOptionsCount = 3
|
||||||
registerCacheExpiration = time.Minute * 15
|
authCacheExpiration = time.Minute * 15
|
||||||
registerCacheCleanup = time.Minute * 20
|
authCacheCleanup = time.Minute * 20
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -44,17 +43,21 @@ var (
|
|||||||
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
|
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
// AuthInfo contains both auth ID and verifier information for OIDC validation.
|
||||||
type RegistrationInfo struct {
|
type AuthInfo struct {
|
||||||
RegistrationID types.RegistrationID
|
AuthID types.AuthID
|
||||||
Verifier *string
|
Verifier *string
|
||||||
|
Registration bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthProviderOIDC struct {
|
type AuthProviderOIDC struct {
|
||||||
h *Headscale
|
h *Headscale
|
||||||
serverURL string
|
serverURL string
|
||||||
cfg *types.OIDCConfig
|
cfg *types.OIDCConfig
|
||||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
|
||||||
|
// authCache holds auth information between
|
||||||
|
// the auth and the callback steps.
|
||||||
|
authCache *zcache.Cache[string, AuthInfo]
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
@@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
|
|||||||
Scopes: cfg.Scope,
|
Scopes: cfg.Scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationCache := zcache.New[string, RegistrationInfo](
|
authCache := zcache.New[string, AuthInfo](
|
||||||
registerCacheExpiration,
|
authCacheExpiration,
|
||||||
registerCacheCleanup,
|
authCacheCleanup,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &AuthProviderOIDC{
|
return &AuthProviderOIDC{
|
||||||
h: h,
|
h: h,
|
||||||
serverURL: serverURL,
|
serverURL: serverURL,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
registrationCache: registrationCache,
|
authCache: authCache,
|
||||||
|
|
||||||
oidcProvider: oidcProvider,
|
oidcProvider: oidcProvider,
|
||||||
oauth2Config: oauth2Config,
|
oauth2Config: oauth2Config,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
|
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%s/auth/%s",
|
||||||
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
|
authID.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderOIDC) AuthHandler(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
) {
|
||||||
|
a.authHandler(writer, req, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"%s/register/%s",
|
"%s/register/%s",
|
||||||
strings.TrimSuffix(a.serverURL, "/"),
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
registrationID.String())
|
authID.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterHandler registers the OIDC callback handler with the given router.
|
// RegisterHandler registers the OIDC callback handler with the given router.
|
||||||
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
|
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
|
||||||
// Listens in /register/:registration_id.
|
// Listens in /register/:auth_id.
|
||||||
func (a *AuthProviderOIDC) RegisterHandler(
|
func (a *AuthProviderOIDC) RegisterHandler(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
vars := mux.Vars(req)
|
a.authHandler(writer, req, true)
|
||||||
registrationIdStr := vars["registration_id"]
|
}
|
||||||
|
|
||||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
// authHandler takes an incoming request that needs to be authenticated and
|
||||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
// validates and prepares it for the OIDC flow.
|
||||||
// the template and log an error.
|
func (a *AuthProviderOIDC) authHandler(
|
||||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
writer http.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
registration bool,
|
||||||
|
) {
|
||||||
|
authID, err := authIDFromRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize registration info with machine key
|
registrationInfo := AuthInfo{
|
||||||
registrationInfo := RegistrationInfo{
|
AuthID: authID,
|
||||||
RegistrationID: registrationId,
|
Registration: registration,
|
||||||
}
|
}
|
||||||
|
|
||||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||||
@@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||||||
extras = append(extras, oidc.Nonce(nonce))
|
extras = append(extras, oidc.Nonce(nonce))
|
||||||
|
|
||||||
// Cache the registration info
|
// Cache the registration info
|
||||||
a.registrationCache.Set(state, registrationInfo)
|
a.authCache.Set(state, registrationInfo)
|
||||||
|
|
||||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||||
@@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
// If the node exists, then the node should be reauthenticated,
|
// If the node exists, then the node should be reauthenticated,
|
||||||
// if the node does not exist, and the machine key exists, then
|
// if the node does not exist, and the machine key exists, then
|
||||||
// this is a new node that should be registered.
|
// this is a new node that should be registered.
|
||||||
registrationId := a.getRegistrationIDFromState(state)
|
authInfo := a.getAuthInfoFromState(state)
|
||||||
|
if authInfo == nil {
|
||||||
|
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
|
||||||
|
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||||
|
|
||||||
// Register the node if it does not exist.
|
return
|
||||||
if registrationId != nil {
|
}
|
||||||
|
|
||||||
|
// If this is a registration flow, then we need to register the node.
|
||||||
|
if authInfo.Registration {
|
||||||
verb := "Reauthenticated"
|
verb := "Reauthenticated"
|
||||||
|
|
||||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
|
||||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Neither node nor machine key was found in the state cache meaning
|
// TODO(kradalby): handle login flow (without registration) if needed.
|
||||||
// that we could not reauth nor register the node.
|
// We need to send an update here to whatever might be waiting for this auth flow.
|
||||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
||||||
@@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
var exchangeOpts []oauth2.AuthCodeOption
|
var exchangeOpts []oauth2.AuthCodeOption
|
||||||
|
|
||||||
if a.cfg.PKCE.Enabled {
|
if a.cfg.PKCE.Enabled {
|
||||||
regInfo, ok := a.registrationCache.Get(state)
|
regInfo, ok := a.authCache.Get(state)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||||
}
|
}
|
||||||
@@ -507,14 +533,14 @@ func doOIDCAuthorization(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRegistrationIDFromState retrieves the registration ID from the state.
|
// getAuthInfoFromState retrieves the registration ID from the state.
|
||||||
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
|
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
|
||||||
regInfo, ok := a.registrationCache.Get(state)
|
authInfo, ok := a.authCache.Get(state)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ®Info.RegistrationID
|
return &authInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||||
@@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||||||
|
|
||||||
func (a *AuthProviderOIDC) handleRegistration(
|
func (a *AuthProviderOIDC) handleRegistration(
|
||||||
user *types.User,
|
user *types.User,
|
||||||
registrationID types.RegistrationID,
|
registrationID types.AuthID,
|
||||||
expiry time.Time,
|
expiry time.Time,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||||
|
|||||||
@@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
|
|||||||
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||||
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||||
|
|
||||||
|
// ErrRegistrationExpired is returned when a registration has expired.
|
||||||
|
var ErrRegistrationExpired = errors.New("registration expired")
|
||||||
|
|
||||||
// State manages Headscale's core state, coordinating between database, policy management,
|
// State manages Headscale's core state, coordinating between database, policy management,
|
||||||
// IP allocation, and DERP routing. All methods are thread-safe.
|
// IP allocation, and DERP routing. All methods are thread-safe.
|
||||||
type State struct {
|
type State struct {
|
||||||
@@ -82,8 +85,10 @@ type State struct {
|
|||||||
derpMap atomic.Pointer[tailcfg.DERPMap]
|
derpMap atomic.Pointer[tailcfg.DERPMap]
|
||||||
// polMan handles policy evaluation and management
|
// polMan handles policy evaluation and management
|
||||||
polMan policy.PolicyManager
|
polMan policy.PolicyManager
|
||||||
// registrationCache caches node registration data to reduce database load
|
|
||||||
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
||||||
|
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||||
|
|
||||||
// primaryRoutes tracks primary route assignments for nodes
|
// primaryRoutes tracks primary route assignments for nodes
|
||||||
primaryRoutes *routes.PrimaryRoutes
|
primaryRoutes *routes.PrimaryRoutes
|
||||||
}
|
}
|
||||||
@@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
||||||
cacheExpiration,
|
cacheExpiration,
|
||||||
cacheCleanup,
|
cacheCleanup,
|
||||||
)
|
)
|
||||||
|
|
||||||
registrationCache.OnEvicted(
|
authCache.OnEvicted(
|
||||||
func(id types.RegistrationID, rn types.RegisterNode) {
|
func(id types.AuthID, rn types.AuthRequest) {
|
||||||
rn.SendAndClose(nil)
|
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
db, err := hsdb.NewHeadscaleDatabase(
|
db, err := hsdb.NewHeadscaleDatabase(
|
||||||
cfg,
|
cfg,
|
||||||
registrationCache,
|
authCache,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("initializing database: %w", err)
|
return nil, fmt.Errorf("initializing database: %w", err)
|
||||||
@@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||||||
return &State{
|
return &State{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
|
||||||
db: db,
|
db: db,
|
||||||
ipAlloc: ipAlloc,
|
ipAlloc: ipAlloc,
|
||||||
polMan: polMan,
|
polMan: polMan,
|
||||||
registrationCache: registrationCache,
|
authCache: authCache,
|
||||||
primaryRoutes: routes.New(),
|
primaryRoutes: routes.New(),
|
||||||
nodeStore: nodeStore,
|
nodeStore: nodeStore,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1057,9 +1062,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
|||||||
return s.db.DeletePreAuthKey(id)
|
return s.db.DeletePreAuthKey(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRegistrationCacheEntry retrieves a node registration from cache.
|
// GetAuthCacheEntry retrieves a node registration from cache.
|
||||||
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
|
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||||
entry, found := s.registrationCache.Get(id)
|
entry, found := s.authCache.Get(id)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -1067,26 +1072,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRegistrationCacheEntry stores a node registration in cache.
|
// SetAuthCacheEntry stores a node registration in cache.
|
||||||
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
|
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||||
s.registrationCache.Set(id, entry)
|
s.authCache.Set(id, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||||
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
|
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||||
if hostinfo == nil {
|
if !nv.Hostinfo().Valid() {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Caller().
|
Caller().
|
||||||
Str(zf.MachineKey, machineKey).
|
EmbedObject(nv).
|
||||||
Str(zf.NodeKey, nodeKey).
|
|
||||||
Str(zf.UserName, username).
|
Str(zf.UserName, username).
|
||||||
Str(zf.GeneratedHostname, hostname).
|
Str(zf.GeneratedHostname, hostname).
|
||||||
Msg("Registration had nil hostinfo, generated default hostname")
|
Msg("Registration had nil hostinfo, generated default hostname")
|
||||||
} else if hostinfo.Hostname == "" {
|
} else if nv.Hostinfo().Hostname() == "" {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Caller().
|
Caller().
|
||||||
Str(zf.MachineKey, machineKey).
|
EmbedObject(nv).
|
||||||
Str(zf.NodeKey, nodeKey).
|
|
||||||
Str(zf.UserName, username).
|
Str(zf.UserName, username).
|
||||||
Str(zf.GeneratedHostname, hostname).
|
Str(zf.GeneratedHostname, hostname).
|
||||||
Msg("Registration had empty hostname, generated default")
|
Msg("Registration had empty hostname, generated default")
|
||||||
@@ -1128,7 +1131,7 @@ type authNodeUpdateParams struct {
|
|||||||
// Node to update; must be valid and in NodeStore.
|
// Node to update; must be valid and in NodeStore.
|
||||||
ExistingNode types.NodeView
|
ExistingNode types.NodeView
|
||||||
// Client data: keys, hostinfo, endpoints.
|
// Client data: keys, hostinfo, endpoints.
|
||||||
RegEntry *types.RegisterNode
|
RegEntry *types.AuthRequest
|
||||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||||
ValidHostinfo *tailcfg.Hostinfo
|
ValidHostinfo *tailcfg.Hostinfo
|
||||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||||
@@ -1147,6 +1150,7 @@ type authNodeUpdateParams struct {
|
|||||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||||
// persists changes to the database.
|
// persists changes to the database.
|
||||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||||
|
regNv := params.RegEntry.Node()
|
||||||
// Log the operation type
|
// Log the operation type
|
||||||
if params.IsConvertFromTag {
|
if params.IsConvertFromTag {
|
||||||
log.Info().
|
log.Info().
|
||||||
@@ -1155,16 +1159,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
Msg("Converting tagged node to user-owned node")
|
Msg("Converting tagged node to user-owned node")
|
||||||
} else {
|
} else {
|
||||||
log.Info().
|
log.Info().
|
||||||
EmbedObject(params.ExistingNode).
|
Object("existing", params.ExistingNode).
|
||||||
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
|
Object("incoming", regNv).
|
||||||
Msg("Updating existing node registration via reauth")
|
Msg("Updating existing node registration via reauth")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process RequestTags during reauth (#2979)
|
// Process RequestTags during reauth (#2979)
|
||||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
||||||
var requestTags []string
|
var requestTags []string
|
||||||
if params.RegEntry.Node.Hostinfo != nil {
|
if regNv.Hostinfo().Valid() {
|
||||||
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
|
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
||||||
}
|
}
|
||||||
|
|
||||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||||
@@ -1182,8 +1186,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
|
|
||||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||||
node.NodeKey = params.RegEntry.Node.NodeKey
|
node.NodeKey = regNv.NodeKey()
|
||||||
node.DiscoKey = params.RegEntry.Node.DiscoKey
|
node.DiscoKey = regNv.DiscoKey()
|
||||||
node.Hostname = params.Hostname
|
node.Hostname = params.Hostname
|
||||||
|
|
||||||
// Preserve NetInfo from existing node when re-registering
|
// Preserve NetInfo from existing node when re-registering
|
||||||
@@ -1194,7 +1198,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
params.ValidHostinfo,
|
params.ValidHostinfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
node.Endpoints = params.RegEntry.Node.Endpoints
|
node.Endpoints = regNv.Endpoints().AsSlice()
|
||||||
node.IsOnline = new(false)
|
node.IsOnline = new(false)
|
||||||
node.LastSeen = new(time.Now())
|
node.LastSeen = new(time.Now())
|
||||||
|
|
||||||
@@ -1203,7 +1207,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
if params.IsConvertFromTag {
|
if params.IsConvertFromTag {
|
||||||
node.RegisterMethod = params.RegisterMethod
|
node.RegisterMethod = params.RegisterMethod
|
||||||
} else {
|
} else {
|
||||||
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
|
node.RegisterMethod = regNv.RegisterMethod()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track tagged status BEFORE processing tags
|
// Track tagged status BEFORE processing tags
|
||||||
@@ -1223,7 +1227,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = params.RegEntry.Node.Expiry
|
node.Expiry = regNv.Expiry().Clone()
|
||||||
}
|
}
|
||||||
case !wasTagged && isTagged:
|
case !wasTagged && isTagged:
|
||||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||||
@@ -1233,14 +1237,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = params.RegEntry.Node.Expiry
|
node.Expiry = regNv.Expiry().Clone()
|
||||||
}
|
}
|
||||||
case !isTagged:
|
case !isTagged:
|
||||||
// Personal → Personal: update expiry from client
|
// Personal → Personal: update expiry from client
|
||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = params.RegEntry.Node.Expiry
|
node.Expiry = regNv.Expiry().Clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||||
@@ -1527,13 +1531,13 @@ func (s *State) processReauthTags(
|
|||||||
|
|
||||||
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
|
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
|
||||||
func (s *State) HandleNodeFromAuthPath(
|
func (s *State) HandleNodeFromAuthPath(
|
||||||
registrationID types.RegistrationID,
|
authID types.AuthID,
|
||||||
userID types.UserID,
|
userID types.UserID,
|
||||||
expiry *time.Time,
|
expiry *time.Time,
|
||||||
registrationMethod string,
|
registrationMethod string,
|
||||||
) (types.NodeView, change.Change, error) {
|
) (types.NodeView, change.Change, error) {
|
||||||
// Get the registration entry from cache
|
// Get the registration entry from cache
|
||||||
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
regEntry, ok := s.GetAuthCacheEntry(authID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
|
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
|
||||||
}
|
}
|
||||||
@@ -1546,25 +1550,27 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
|
|
||||||
// Ensure we have a valid hostname from the registration cache entry
|
// Ensure we have a valid hostname from the registration cache entry
|
||||||
hostname := util.EnsureHostname(
|
hostname := util.EnsureHostname(
|
||||||
regEntry.Node.Hostinfo,
|
regEntry.Node().Hostinfo(),
|
||||||
regEntry.Node.MachineKey.String(),
|
regEntry.Node().MachineKey().String(),
|
||||||
regEntry.Node.NodeKey.String(),
|
regEntry.Node().NodeKey().String(),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Ensure we have valid hostinfo
|
// Ensure we have valid hostinfo
|
||||||
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
|
hostinfo := &tailcfg.Hostinfo{}
|
||||||
validHostinfo.Hostname = hostname
|
if regEntry.Node().Hostinfo().Valid() {
|
||||||
|
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.Hostname = hostname
|
||||||
|
|
||||||
logHostinfoValidation(
|
logHostinfoValidation(
|
||||||
regEntry.Node.MachineKey.ShortString(),
|
regEntry.Node(),
|
||||||
regEntry.Node.NodeKey.String(),
|
|
||||||
user.Name,
|
user.Name,
|
||||||
hostname,
|
hostname,
|
||||||
regEntry.Node.Hostinfo,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Lookup existing nodes
|
// Lookup existing nodes
|
||||||
machineKey := regEntry.Node.MachineKey
|
machineKey := regEntry.Node().MachineKey()
|
||||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||||
|
|
||||||
@@ -1578,7 +1584,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
|
|
||||||
// Create logger with common fields for all auth operations
|
// Create logger with common fields for all auth operations
|
||||||
logger := log.With().
|
logger := log.With().
|
||||||
Str(zf.RegistrationID, registrationID.String()).
|
Str(zf.RegistrationID, authID.String()).
|
||||||
Str(zf.UserName, user.Name).
|
Str(zf.UserName, user.Name).
|
||||||
Str(zf.MachineKey, machineKey.ShortString()).
|
Str(zf.MachineKey, machineKey.ShortString()).
|
||||||
Str(zf.Method, registrationMethod).
|
Str(zf.Method, registrationMethod).
|
||||||
@@ -1587,7 +1593,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
// Common params for update operations
|
// Common params for update operations
|
||||||
updateParams := authNodeUpdateParams{
|
updateParams := authNodeUpdateParams{
|
||||||
RegEntry: regEntry,
|
RegEntry: regEntry,
|
||||||
ValidHostinfo: validHostinfo,
|
ValidHostinfo: hostinfo,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
User: user,
|
User: user,
|
||||||
Expiry: expiry,
|
Expiry: expiry,
|
||||||
@@ -1621,7 +1627,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||||
|
|
||||||
finalNode, err = s.createNewNodeFromAuth(
|
finalNode, err = s.createNewNodeFromAuth(
|
||||||
logger, user, regEntry, hostname, validHostinfo,
|
logger, user, regEntry, hostname, hostinfo,
|
||||||
expiry, registrationMethod, existingNodeAnyUser,
|
expiry, registrationMethod, existingNodeAnyUser,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1629,7 +1635,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
finalNode, err = s.createNewNodeFromAuth(
|
finalNode, err = s.createNewNodeFromAuth(
|
||||||
logger, user, regEntry, hostname, validHostinfo,
|
logger, user, regEntry, hostname, hostinfo,
|
||||||
expiry, registrationMethod, types.NodeView{},
|
expiry, registrationMethod, types.NodeView{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1638,10 +1644,10 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Signal to waiting clients
|
// Signal to waiting clients
|
||||||
regEntry.SendAndClose(finalNode.AsStruct())
|
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
||||||
|
|
||||||
// Delete from registration cache
|
// Delete from registration cache
|
||||||
s.registrationCache.Delete(registrationID)
|
s.authCache.Delete(authID)
|
||||||
|
|
||||||
// Update policy managers
|
// Update policy managers
|
||||||
usersChange, err := s.updatePolicyManagerUsers()
|
usersChange, err := s.updatePolicyManagerUsers()
|
||||||
@@ -1670,7 +1676,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
func (s *State) createNewNodeFromAuth(
|
func (s *State) createNewNodeFromAuth(
|
||||||
logger zerolog.Logger,
|
logger zerolog.Logger,
|
||||||
user *types.User,
|
user *types.User,
|
||||||
regEntry *types.RegisterNode,
|
regEntry *types.AuthRequest,
|
||||||
hostname string,
|
hostname string,
|
||||||
validHostinfo *tailcfg.Hostinfo,
|
validHostinfo *tailcfg.Hostinfo,
|
||||||
expiry *time.Time,
|
expiry *time.Time,
|
||||||
@@ -1683,13 +1689,13 @@ func (s *State) createNewNodeFromAuth(
|
|||||||
|
|
||||||
return s.createAndSaveNewNode(newNodeParams{
|
return s.createAndSaveNewNode(newNodeParams{
|
||||||
User: *user,
|
User: *user,
|
||||||
MachineKey: regEntry.Node.MachineKey,
|
MachineKey: regEntry.Node().MachineKey(),
|
||||||
NodeKey: regEntry.Node.NodeKey,
|
NodeKey: regEntry.Node().NodeKey(),
|
||||||
DiscoKey: regEntry.Node.DiscoKey,
|
DiscoKey: regEntry.Node().DiscoKey(),
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
Hostinfo: validHostinfo,
|
Hostinfo: validHostinfo,
|
||||||
Endpoints: regEntry.Node.Endpoints,
|
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
||||||
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
|
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
||||||
RegisterMethod: registrationMethod,
|
RegisterMethod: registrationMethod,
|
||||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||||
})
|
})
|
||||||
@@ -1784,7 +1790,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
|
|
||||||
// Ensure we have a valid hostname - handle nil/empty cases
|
// Ensure we have a valid hostname - handle nil/empty cases
|
||||||
hostname := util.EnsureHostname(
|
hostname := util.EnsureHostname(
|
||||||
regReq.Hostinfo,
|
regReq.Hostinfo.View(),
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
regReq.NodeKey.String(),
|
regReq.NodeKey.String(),
|
||||||
)
|
)
|
||||||
@@ -1793,14 +1799,6 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
|
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
|
||||||
validHostinfo.Hostname = hostname
|
validHostinfo.Hostname = hostname
|
||||||
|
|
||||||
logHostinfoValidation(
|
|
||||||
machineKey.ShortString(),
|
|
||||||
regReq.NodeKey.ShortString(),
|
|
||||||
pakUsername(),
|
|
||||||
hostname,
|
|
||||||
regReq.Hostinfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Str(zf.NodeName, hostname).
|
Str(zf.NodeName, hostname).
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterWeb(registrationID types.RegistrationID) *elem.Element {
|
func RegisterWeb(registrationID types.AuthID) *elem.Element {
|
||||||
return HtmlStructure(
|
return HtmlStructure(
|
||||||
elem.Title(nil, elem.Text("Registration - Headscale")),
|
elem.Title(nil, elem.Text("Registration - Headscale")),
|
||||||
mdTypesetBody(
|
mdTypesetBody(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Register Web",
|
name: "Register Web",
|
||||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Windows Config",
|
name: "Windows Config",
|
||||||
@@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Register Web",
|
name: "Register Web",
|
||||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Windows Config",
|
name: "Windows Config",
|
||||||
@@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Register Web",
|
name: "Register Web",
|
||||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||||
externalURLs: []string{}, // No external links
|
externalURLs: []string{}, // No external links
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Register Web",
|
name: "Register Web",
|
||||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Windows Config",
|
name: "Windows Config",
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ const (
|
|||||||
|
|
||||||
// Common errors.
|
// Common errors.
|
||||||
var (
|
var (
|
||||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||||
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
|
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
|
||||||
)
|
)
|
||||||
|
|
||||||
type StateUpdateType int
|
type StateUpdateType int
|
||||||
@@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const RegistrationIDLength = 24
|
const AuthIDLength = 24
|
||||||
|
|
||||||
type RegistrationID string
|
type AuthID string
|
||||||
|
|
||||||
func NewRegistrationID() (RegistrationID, error) {
|
func NewAuthID() (AuthID, error) {
|
||||||
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
|
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return RegistrationID(rid), nil
|
return AuthID(rid), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustRegistrationID() RegistrationID {
|
func MustAuthID() AuthID {
|
||||||
rid, err := NewRegistrationID()
|
rid, err := NewAuthID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -181,43 +181,89 @@ func MustRegistrationID() RegistrationID {
|
|||||||
return rid
|
return rid
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegistrationIDFromString(str string) (RegistrationID, error) {
|
func AuthIDFromString(str string) (AuthID, error) {
|
||||||
if len(str) != RegistrationIDLength {
|
r := AuthID(str)
|
||||||
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
|
|
||||||
|
err := r.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return RegistrationID(str), nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r RegistrationID) String() string {
|
func (r AuthID) String() string {
|
||||||
return string(r)
|
return string(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RegisterNode struct {
|
func (r AuthID) Validate() error {
|
||||||
Node Node
|
if len(r) != AuthIDLength {
|
||||||
Registered chan *Node
|
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
|
||||||
closed *atomic.Bool
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRegisterNode(node Node) RegisterNode {
|
// AuthRequest represent a pending authentication request from a user or a node.
|
||||||
return RegisterNode{
|
// If it is a registration request, the node field will be populate with the node that is trying to register.
|
||||||
Node: node,
|
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
|
||||||
Registered: make(chan *Node),
|
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
|
||||||
closed: &atomic.Bool{},
|
type AuthRequest struct {
|
||||||
|
node *Node
|
||||||
|
finished chan AuthVerdict
|
||||||
|
closed *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegisterAuthRequest(node Node) AuthRequest {
|
||||||
|
return AuthRequest{
|
||||||
|
node: &node,
|
||||||
|
finished: make(chan AuthVerdict),
|
||||||
|
closed: &atomic.Bool{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rn *RegisterNode) SendAndClose(node *Node) {
|
// Node returns the node that is trying to register.
|
||||||
|
// It will panic if the AuthRequest is not a registration request.
|
||||||
|
// Can _only_ be used in the registration path.
|
||||||
|
func (rn *AuthRequest) Node() NodeView {
|
||||||
|
if rn.node == nil {
|
||||||
|
panic("Node can only be used in registration requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
return rn.node.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
||||||
if rn.closed.Swap(true) {
|
if rn.closed.Swap(true) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case rn.Registered <- node:
|
case rn.finished <- verdict:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
close(rn.Registered)
|
close(rn.finished)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict {
|
||||||
|
return rn.finished
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthVerdict struct {
|
||||||
|
// Err is the error that occurred during the authentication process, if any.
|
||||||
|
// If Err is nil, the authentication process has succeeded.
|
||||||
|
// If Err is not nil, the authentication process has failed and the node should not be authenticated.
|
||||||
|
Err error
|
||||||
|
|
||||||
|
// Node is the node that has been authenticated.
|
||||||
|
// Node is only valid if the auth request was a registration request
|
||||||
|
// and the authentication process has succeeded.
|
||||||
|
Node NodeView
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v AuthVerdict) Accept() bool {
|
||||||
|
return v.Err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||||
|
|||||||
@@ -295,8 +295,8 @@ func IsCI() bool {
|
|||||||
// 3. If normalisation fails → generate invalid-<random> replacement
|
// 3. If normalisation fails → generate invalid-<random> replacement
|
||||||
//
|
//
|
||||||
// Returns the guaranteed-valid hostname to use.
|
// Returns the guaranteed-valid hostname to use.
|
||||||
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
|
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
|
||||||
if hostinfo == nil || hostinfo.Hostname == "" {
|
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
|
||||||
key := cmp.Or(machineKey, nodeKey)
|
key := cmp.Or(machineKey, nodeKey)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return "unknown-node"
|
return "unknown-node"
|
||||||
@@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
|
|||||||
return "node-" + keyPrefix
|
return "node-" + keyPrefix
|
||||||
}
|
}
|
||||||
|
|
||||||
lowercased := strings.ToLower(hostinfo.Hostname)
|
lowercased := strings.ToLower(hostinfo.Hostname())
|
||||||
|
|
||||||
err := ValidateHostname(lowercased)
|
err := ValidateHostname(lowercased)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
|
||||||
// For invalid hostnames, we just check the prefix since the random part varies
|
// For invalid hostnames, we just check the prefix since the random part varies
|
||||||
if strings.HasPrefix(tt.want, "invalid-") {
|
if strings.HasPrefix(tt.want, "invalid-") {
|
||||||
if !strings.HasPrefix(got, "invalid-") {
|
if !strings.HasPrefix(got, "invalid-") {
|
||||||
@@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
|
||||||
// For invalid hostnames, we just check the prefix since the random part varies
|
// For invalid hostnames, we just check the prefix since the random part varies
|
||||||
if strings.HasPrefix(tt.wantHostname, "invalid-") {
|
if strings.HasPrefix(tt.wantHostname, "invalid-") {
|
||||||
if !strings.HasPrefix(gotHostname, "invalid-") {
|
if !strings.HasPrefix(gotHostname, "invalid-") {
|
||||||
@@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
|
|||||||
|
|
||||||
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
|
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
|
||||||
|
|
||||||
result := EnsureHostname(hostinfo, "mkey", "nkey")
|
result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
|
||||||
if len(result) > 63 {
|
if len(result) > 63 {
|
||||||
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
|
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
|
||||||
}
|
}
|
||||||
@@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
|
|||||||
OS: "linux",
|
OS: "linux",
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
|
hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
|
||||||
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
|
hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
|
||||||
|
|
||||||
if hostname1 != hostname2 {
|
if hostname1 != hostname2 {
|
||||||
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
|
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
|
||||||
|
|||||||
@@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
regIDs := []string{
|
regIDs := []string{
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(regIDs))
|
nodes := make([]*v1.Node, len(regIDs))
|
||||||
|
|
||||||
@@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
|
|||||||
assert.Equal(t, "node-5", listAll[4].GetName())
|
assert.Equal(t, "node-5", listAll[4].GetName())
|
||||||
|
|
||||||
otherUserRegIDs := []string{
|
otherUserRegIDs := []string{
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
}
|
}
|
||||||
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
|
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
|
||||||
|
|
||||||
@@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
regIDs := []string{
|
regIDs := []string{
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(regIDs))
|
nodes := make([]*v1.Node, len(regIDs))
|
||||||
|
|
||||||
@@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
regIDs := []string{
|
regIDs := []string{
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
types.MustRegistrationID().String(),
|
types.MustAuthID().String(),
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(regIDs))
|
nodes := make([]*v1.Node, len(regIDs))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user