mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-21 08:11:43 +02:00
users: harden, test, and add cleaner of identifier (#2593)
* users: harden, test, and add cleaner of identifier Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * db: migrate badly joined provider identifiers Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
@@ -194,13 +194,110 @@ type OIDCClaims struct {
|
||||
Username string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
// Identifier returns a unique identifier string combining the Iss and Sub claims.
|
||||
// The format depends on whether Iss is a URL or not:
|
||||
// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub")
|
||||
// - For non-URLs: Joins with a slash (e.g., "oidc/sub")
|
||||
// - For empty Iss: Returns just "sub"
|
||||
// - For empty Sub: Returns just the Issuer
|
||||
// - For both empty: Returns empty string
|
||||
//
|
||||
// The result is cleaned using CleanIdentifier() to ensure consistent formatting.
|
||||
func (c *OIDCClaims) Identifier() string {
|
||||
if strings.HasPrefix(c.Iss, "http") {
|
||||
if i, err := url.JoinPath(c.Iss, c.Sub); err == nil {
|
||||
return i
|
||||
// Handle empty components special cases
|
||||
if c.Iss == "" && c.Sub == "" {
|
||||
return ""
|
||||
}
|
||||
if c.Iss == "" {
|
||||
return CleanIdentifier(c.Sub)
|
||||
}
|
||||
if c.Sub == "" {
|
||||
return CleanIdentifier(c.Iss)
|
||||
}
|
||||
|
||||
// We'll use the raw values and let CleanIdentifier handle all the whitespace
|
||||
issuer := c.Iss
|
||||
subject := c.Sub
|
||||
|
||||
var result string
|
||||
// Try to parse as URL to handle URL joining correctly
|
||||
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
|
||||
// For URLs, use proper URL path joining
|
||||
if joined, err := url.JoinPath(issuer, subject); err == nil {
|
||||
result = joined
|
||||
}
|
||||
}
|
||||
return c.Iss + "/" + c.Sub
|
||||
|
||||
// If URL joining failed or issuer wasn't a URL, do simple string join
|
||||
if result == "" {
|
||||
// Default case: simple string joining with slash
|
||||
issuer = strings.TrimSuffix(issuer, "/")
|
||||
subject = strings.TrimPrefix(subject, "/")
|
||||
result = issuer + "/" + subject
|
||||
}
|
||||
|
||||
// Clean the result and return it
|
||||
return CleanIdentifier(result)
|
||||
}
|
||||
|
||||
// CleanIdentifier cleans a potentially malformed identifier by removing double slashes
|
||||
// while preserving protocol specifications like http://. This function will:
|
||||
// - Trim all whitespace from the beginning and end of the identifier
|
||||
// - Remove whitespace within path segments
|
||||
// - Preserve the scheme (http://, https://, etc.) for URLs
|
||||
// - Remove any duplicate slashes in the path
|
||||
// - Remove empty path segments
|
||||
// - For non-URL identifiers, it joins non-empty segments with a single slash
|
||||
// - Returns empty string for identifiers with only slashes
|
||||
// - Normalize URL schemes to lowercase
|
||||
func CleanIdentifier(identifier string) string {
|
||||
if identifier == "" {
|
||||
return identifier
|
||||
}
|
||||
|
||||
// Trim leading/trailing whitespace
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
|
||||
// Handle URLs with schemes
|
||||
u, err := url.Parse(identifier)
|
||||
if err == nil && u.Scheme != "" {
|
||||
// Clean path by removing empty segments and whitespace within segments
|
||||
parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' })
|
||||
for i, part := range parts {
|
||||
parts[i] = strings.TrimSpace(part)
|
||||
}
|
||||
// Remove empty parts after trimming
|
||||
cleanParts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if part != "" {
|
||||
cleanParts = append(cleanParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
u.Path = ""
|
||||
} else {
|
||||
u.Path = "/" + strings.Join(cleanParts, "/")
|
||||
}
|
||||
// Ensure scheme is lowercase
|
||||
u.Scheme = strings.ToLower(u.Scheme)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Handle non-URL identifiers
|
||||
parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' })
|
||||
// Clean whitespace from each part
|
||||
cleanParts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
cleanParts = append(cleanParts, trimmed)
|
||||
}
|
||||
}
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(cleanParts, "/")
|
||||
}
|
||||
|
||||
type OIDCUserInfo struct {
|
||||
@@ -231,7 +328,13 @@ func (u *User) FromClaim(claims *OIDCClaims) {
|
||||
}
|
||||
}
|
||||
|
||||
u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
|
||||
// Get provider identifier
|
||||
identifier := claims.Identifier()
|
||||
// Ensure provider identifier always has a leading slash for backward compatibility
|
||||
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
|
||||
identifier = "/" + identifier
|
||||
}
|
||||
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
|
||||
u.DisplayName = claims.Name
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
u.Provider = util.RegisterMethodOIDC
|
||||
|
||||
Reference in New Issue
Block a user