mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-18 14:59:54 +02:00
auth: generalise auth flow and introduce AuthVerdict
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
This commit is contained in:
@@ -22,8 +22,8 @@ const (
|
||||
|
||||
// Common errors.
|
||||
var (
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
|
||||
)
|
||||
|
||||
type StateUpdateType int
|
||||
@@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||
}
|
||||
}
|
||||
|
||||
const RegistrationIDLength = 24
|
||||
const AuthIDLength = 24
|
||||
|
||||
type RegistrationID string
|
||||
type AuthID string
|
||||
|
||||
func NewRegistrationID() (RegistrationID, error) {
|
||||
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
|
||||
func NewAuthID() (AuthID, error) {
|
||||
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return RegistrationID(rid), nil
|
||||
return AuthID(rid), nil
|
||||
}
|
||||
|
||||
func MustRegistrationID() RegistrationID {
|
||||
rid, err := NewRegistrationID()
|
||||
func MustAuthID() AuthID {
|
||||
rid, err := NewAuthID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -181,43 +181,89 @@ func MustRegistrationID() RegistrationID {
|
||||
return rid
|
||||
}
|
||||
|
||||
func RegistrationIDFromString(str string) (RegistrationID, error) {
|
||||
if len(str) != RegistrationIDLength {
|
||||
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
|
||||
func AuthIDFromString(str string) (AuthID, error) {
|
||||
r := AuthID(str)
|
||||
|
||||
err := r.Validate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return RegistrationID(str), nil
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r RegistrationID) String() string {
|
||||
func (r AuthID) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
closed *atomic.Bool
|
||||
func (r AuthID) Validate() error {
|
||||
if len(r) != AuthIDLength {
|
||||
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRegisterNode(node Node) RegisterNode {
|
||||
return RegisterNode{
|
||||
Node: node,
|
||||
Registered: make(chan *Node),
|
||||
closed: &atomic.Bool{},
|
||||
// AuthRequest represent a pending authentication request from a user or a node.
|
||||
// If it is a registration request, the node field will be populate with the node that is trying to register.
|
||||
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
|
||||
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
|
||||
type AuthRequest struct {
|
||||
node *Node
|
||||
finished chan AuthVerdict
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewRegisterAuthRequest(node Node) AuthRequest {
|
||||
return AuthRequest{
|
||||
node: &node,
|
||||
finished: make(chan AuthVerdict),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rn *RegisterNode) SendAndClose(node *Node) {
|
||||
// Node returns the node that is trying to register.
|
||||
// It will panic if the AuthRequest is not a registration request.
|
||||
// Can _only_ be used in the registration path.
|
||||
func (rn *AuthRequest) Node() NodeView {
|
||||
if rn.node == nil {
|
||||
panic("Node can only be used in registration requests")
|
||||
}
|
||||
|
||||
return rn.node.View()
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
||||
if rn.closed.Swap(true) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case rn.Registered <- node:
|
||||
case rn.finished <- verdict:
|
||||
default:
|
||||
}
|
||||
|
||||
close(rn.Registered)
|
||||
close(rn.finished)
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict {
|
||||
return rn.finished
|
||||
}
|
||||
|
||||
type AuthVerdict struct {
|
||||
// Err is the error that occurred during the authentication process, if any.
|
||||
// If Err is nil, the authentication process has succeeded.
|
||||
// If Err is not nil, the authentication process has failed and the node should not be authenticated.
|
||||
Err error
|
||||
|
||||
// Node is the node that has been authenticated.
|
||||
// Node is only valid if the auth request was a registration request
|
||||
// and the authentication process has succeeded.
|
||||
Node NodeView
|
||||
}
|
||||
|
||||
func (v AuthVerdict) Accept() bool {
|
||||
return v.Err == nil
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
|
||||
Reference in New Issue
Block a user