Compare commits

..

3 Commits

Author SHA1 Message Date
Juan Font
0d98493360 Reduced the number of containers in integration tests 2022-04-30 21:14:56 +00:00
Juan Font
03659c4175 Updated changelog 2022-04-30 14:50:55 +00:00
Juan Font
843e2bd9b6 Do not setLastStateChangeToNow every 5 seconds 2022-04-30 14:47:16 +00:00
16 changed files with 132 additions and 956 deletions

View File

@@ -7,6 +7,7 @@
- Headscale fails to serve if the ACL policy file cannot be parsed [#537](https://github.com/juanfont/headscale/pull/537) - Headscale fails to serve if the ACL policy file cannot be parsed [#537](https://github.com/juanfont/headscale/pull/537)
- Fix labels cardinality error when registering unknown pre-auth key [#519](https://github.com/juanfont/headscale/pull/519) - Fix labels cardinality error when registering unknown pre-auth key [#519](https://github.com/juanfont/headscale/pull/519)
- Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542) - Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542)
- Fixed spurious calls to setLastStateChangeToNow from ephemeral nodes [#566](https://github.com/juanfont/headscale/pull/566)
## 0.15.0 (2022-03-20) ## 0.15.0 (2022-03-20)

175
api.go
View File

@@ -9,7 +9,6 @@ import (
"html/template" "html/template"
"io" "io"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@@ -22,50 +21,18 @@ import (
) )
const ( const (
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error( ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire", "machines registered with CLI does not support expire",
) )
) )
const (
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28.
// See https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 28
)
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) { func (h *Headscale) KeyHandler(ctx *gin.Context) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
v := ctx.Query("v")
if v != "" {
clientCapabilityVersion, err := strconv.Atoi(v)
if err != nil {
ctx.String(http.StatusBadRequest, "Invalid version")
return
}
if clientCapabilityVersion >= NoiseCapabilityVersion {
// Tailscale has a different key for the TS2021 protocol. Not sure why.
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
ctx.JSON(http.StatusOK, resp)
return
}
}
// Old clients don't send a 'v' parameter, so we send the legacy public key
ctx.Data( ctx.Data(
http.StatusOK, http.StatusOK,
"text/plain; charset=utf-8", "text/plain; charset=utf-8",
@@ -202,7 +169,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
} }
h.registrationCache.Set( h.registrationCache.Set(
NodePublicKeyStripPrefix(req.NodeKey), machineKeyStr,
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )
@@ -321,61 +288,33 @@ func (h *Headscale) getMapResponse(
Msgf("Generated map response: %s", tailMapResponseToString(resp)) Msgf("Generated map response: %s", tailMapResponseToString(resp))
var respBody []byte var respBody []byte
if machineKey.IsZero() { if req.Compress == "zstd" {
// The TS2021 protocol does not rely anymore on the machine key to src, err := json.Marshal(resp)
// encrypt in a NaCl box the map response. We just send it back
// unencrypted via the encrypted Noise channel.
// declare the incoming size on the first 4 bytes
respBody, err := json.Marshal(resp)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Cannot marshal map response") Msg("Failed to marshal response for the client")
return nil, err
} }
var srcCompressed []byte encoder, _ := zstd.NewWriter(nil)
if req.Compress == "zstd" { srcCompressed := encoder.EncodeAll(src, nil)
encoder, _ := zstd.NewWriter(nil) respBody = h.privateKey.SealTo(machineKey, srcCompressed)
srcCompressed = encoder.EncodeAll(respBody, nil)
} else {
srcCompressed = respBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(srcCompressed)))
data = append(data, srcCompressed...)
return data, nil
} else { } else {
if req.Compress == "zstd" { respBody, err = encode(resp, &machineKey, h.privateKey)
src, err := json.Marshal(resp) if err != nil {
if err != nil { return nil, err
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Failed to marshal response for the client")
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} }
// declare the incoming size on the first 4 bytes
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
} }
// declare the incoming size on the first 4 bytes
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
} }
func (h *Headscale) getMapKeepAliveResponse( func (h *Headscale) getMapKeepAliveResponse(
@@ -387,36 +326,31 @@ func (h *Headscale) getMapKeepAliveResponse(
} }
var respBody []byte var respBody []byte
var err error var err error
if machineKey.IsZero() { if mapRequest.Compress == "zstd" {
// The TS2021 protocol does not rely anymore on the machine key. src, err := json.Marshal(mapResponse)
return json.Marshal(mapResponse) if err != nil {
} else { log.Error().
if mapRequest.Compress == "zstd" { Caller().
src, err := json.Marshal(mapResponse) Str("func", "getMapKeepAliveResponse").
if err != nil { Err(err).
log.Error(). Msg("Failed to marshal keepalive response for the client")
Caller().
Str("func", "getMapKeepAliveResponse").
Err(err).
Msg("Failed to marshal keepalive response for the client")
return nil, err return nil, err
} }
encoder, _ := zstd.NewWriter(nil) encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil) srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed) respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else { } else {
respBody, err = encode(mapResponse, &machineKey, h.privateKey) respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil { if err != nil {
return nil, err return nil, err
}
} }
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
} }
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
} }
func (h *Headscale) handleMachineLogOut( func (h *Headscale) handleMachineLogOut(
@@ -477,7 +411,6 @@ func (h *Headscale) handleMachineValidRegistration(
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@@ -504,10 +437,10 @@ func (h *Headscale) handleMachineExpired(
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey) strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
} }
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
@@ -522,7 +455,6 @@ func (h *Headscale) handleMachineExpired(
return return
} }
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@@ -572,21 +504,13 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey), machineKey.String(),
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
} }
if machineKey.IsZero() {
// TS2021
ctx.JSON(http.StatusOK, resp)
return
}
// The Tailscale legacy protocol requires to encrypt the NaCl box with the MachineKey
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@@ -600,6 +524,7 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,

69
app.go
View File

@@ -81,7 +81,6 @@ type Config struct {
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefixes []netaddr.IPPrefix IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string BaseDomain string
DERP DERPConfig DERP DERPConfig
@@ -144,15 +143,12 @@ type CLIConfig struct {
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg Config cfg Config
db *gorm.DB db *gorm.DB
dbString string dbString string
dbType string dbType string
dbDebug bool dbDebug bool
privateKey *key.MachinePrivate privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate
noiseRouter *gin.Engine
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *DERPServer DERPServer *DERPServer
@@ -192,20 +188,11 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
} }
func NewHeadscale(cfg Config) (*Headscale, error) { func NewHeadscale(cfg Config) (*Headscale, error) {
privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err) return nil, fmt.Errorf("failed to read or create private key: %w", err)
} }
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create noise private key: %w", err)
}
if privateKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("private key and noise private key are the same")
}
var dbString string var dbString string
switch cfg.DBtype { switch cfg.DBtype {
case Postgres: case Postgres:
@@ -232,8 +219,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
privateKey: privateKey, privateKey: privKey,
noisePrivateKey: noisePrivateKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache, registrationCache: registrationCache,
} }
@@ -273,10 +259,9 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
} }
// Redirect to our TLS url. // Redirect to our TLS url.
func (h *Headscale) redirect(ctx *gin.Context) { func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
log.Trace().Msgf("Redirecting to TLS, path %s", ctx.Request.RequestURI) target := h.cfg.ServerURL + req.URL.RequestURI()
target := h.cfg.ServerURL + ctx.Request.RequestURI http.Redirect(w, req, target, http.StatusFound)
http.Redirect(ctx.Writer, ctx.Request, target, http.StatusFound)
} }
// expireEphemeralNodes deletes ephemeral machine records that have not been // expireEphemeralNodes deletes ephemeral machine records that have not been
@@ -307,11 +292,13 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return return
} }
expiredFound := false
for _, machine := range machines { for _, machine := range machines {
if machine.AuthKey != nil && machine.LastSeen != nil && if machine.AuthKey != nil && machine.LastSeen != nil &&
machine.AuthKey.Ephemeral && machine.AuthKey.Ephemeral &&
time.Now(). time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
expiredFound = true
log.Info(). log.Info().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
@@ -326,7 +313,9 @@ func (h *Headscale) expireEphemeralNodesWorker() {
} }
} }
h.setLastStateChangeToNow(namespace.Name) if expiredFound {
h.setLastStateChangeToNow(namespace.Name)
}
} }
} }
@@ -479,13 +468,11 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
"/health", "/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
) )
router.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.GET("/key", h.KeyHandler) router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI) router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:nkey", h.RegisterOIDC) router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage) router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/apple/:platform", h.ApplePlatformConfig)
@@ -511,15 +498,6 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
return router return router
} }
func (h *Headscale) createNoiseRouter() *gin.Engine {
router := gin.Default()
router.POST("/machine/register", h.NoiseRegistrationHandler)
router.POST("/machine/map", h.NoisePollNetMapHandler)
return router
}
// Serve launches a GIN server with the Headscale API. // Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
var err error var err error
@@ -685,14 +663,8 @@ func (h *Headscale) Serve() error {
// HTTP setup // HTTP setup
// //
// This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailcale API
router := h.createRouter(grpcGatewayMux) router := h.createRouter(grpcGatewayMux)
// This router is only served over the Noise connection,
// and exposes only the new API
h.noiseRouter = h.createNoiseRouter()
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: router, Handler: router,
@@ -773,14 +745,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
httpRouter := gin.Default()
httpRouter.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
httpRouter.NoRoute(h.redirect)
go func() { go func() {
log.Fatal(). log.Fatal().
Caller(). Caller().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(httpRouter))). Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server") Msg("failed to set up a HTTP server")
}() }()
@@ -818,7 +786,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
} }
func (h *Headscale) setLastStateChangeToNow(namespace string) { func (h *Headscale) setLastStateChangeToNow(namespace string) {
log.Trace().Msgf("setting last state change to now for namespace %s", namespace)
now := time.Now().UTC() now := time.Now().UTC()
lastStateUpdate.WithLabelValues("", "headscale").Set(float64(now.Unix())) lastStateUpdate.WithLabelValues("", "headscale").Set(float64(now.Unix()))
h.lastStateChange.Store(namespace, now) h.lastStateChange.Store(namespace, now)

View File

@@ -326,10 +326,9 @@ func getHeadscaleConfig() headscale.Config {
GRPCAddr: viper.GetString("grpc_listen_addr"), GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
IPPrefixes: prefixes, IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
NoisePrivateKeyPath: absPath(viper.GetString("noise_private_key_path")), BaseDomain: baseDomain,
BaseDomain: baseDomain,
DERP: derpConfig, DERP: derpConfig,

View File

@@ -41,13 +41,6 @@ grpc_allow_insecure: false
# autogenerated if it's missing # autogenerated if it's missing
private_key_path: /var/lib/headscale/private.key private_key_path: /var/lib/headscale/private.key
# The Noise private key is used to encrypt the
# traffic between headscale and Tailscale clients when
# using the new Noise-based TS2021 protocol.
# The noise private key file which will be
# autogenerated if it's missing
noise_private_key_path: /var/lib/headscale/noise_private.key
# List of IP prefixes to allocate tailaddresses from. # List of IP prefixes to allocate tailaddresses from.
# Each prefix consists of either an IPv4 or IPv6 address, # Each prefix consists of either an IPv4 or IPv6 address,
# and the associated prefix length, delimited by a slash. # and the associated prefix length, delimited by a slash.

2
go.mod
View File

@@ -27,7 +27,6 @@ require (
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
github.com/zsais/go-gin-prometheus v0.1.0 github.com/zsais/go-gin-prometheus v0.1.0
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/net v0.0.0-20220412020605-290c469a71a5
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
google.golang.org/genproto v0.0.0-20220422154200-b37d22cd5731 google.golang.org/genproto v0.0.0-20220422154200-b37d22cd5731
@@ -133,6 +132,7 @@ require (
go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect
golang.org/x/net v0.0.0-20220412020605-290c469a71a5 // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect

View File

@@ -5,10 +5,9 @@ import (
"context" "context"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
@@ -374,7 +373,6 @@ func (api headscaleV1APIServer) DebugCreateMachine(
MachineKey: request.GetKey(), MachineKey: request.GetKey(),
Name: request.GetName(), Name: request.GetName(),
Namespace: *namespace, Namespace: *namespace,
NodeKey: key.NewNode().Public().String(),
Expiry: &time.Time{}, Expiry: &time.Time{},
LastSeen: &time.Time{}, LastSeen: &time.Time{},
@@ -384,7 +382,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
} }
api.h.registrationCache.Set( api.h.registrationCache.Set(
newMachine.NodeKey, request.GetKey(),
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )

View File

@@ -47,11 +47,11 @@ func TestIntegrationTestSuite(t *testing.T) {
s.namespaces = map[string]TestNamespace{ s.namespaces = map[string]TestNamespace{
"thisspace": { "thisspace": {
count: 15, count: 10,
tailscales: make(map[string]dockertest.Resource), tailscales: make(map[string]dockertest.Resource),
}, },
"otherspace": { "otherspace": {
count: 5, count: 3,
tailscales: make(map[string]dockertest.Resource), tailscales: make(map[string]dockertest.Resource),
}, },
} }

View File

@@ -13,7 +13,6 @@ dns_config:
- 1.1.1.1 - 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3 db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:8080 listen_addr: 0.0.0.0:8080
metrics_listen_addr: 127.0.0.1:9090 metrics_listen_addr: 127.0.0.1:9090
server_url: http://headscale:8080 server_url: http://headscale:8080

View File

@@ -13,7 +13,6 @@ dns_config:
- 1.1.1.1 - 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3 db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:8443 listen_addr: 0.0.0.0:8443
server_url: https://headscale:8443 server_url: https://headscale:8443
tls_cert_path: "/etc/headscale/tls/server.crt" tls_cert_path: "/etc/headscale/tls/server.crt"

View File

@@ -335,7 +335,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil return &m, nil
} }
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey( func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*Machine, error) { ) (*Machine, error) {
@@ -347,19 +347,6 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil return &m, nil
} }
// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByNodeKeys(
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database // UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database. // and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(machine *Machine) error { func (h *Headscale) UpdateMachine(machine *Machine) error {
@@ -375,7 +362,6 @@ func (h *Headscale) ExpireMachine(machine *Machine) {
now := time.Now() now := time.Now()
machine.Expiry = &now machine.Expiry = &now
log.Trace().Msgf("Expiring machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name) h.setLastStateChangeToNow(machine.Namespace.Name)
h.db.Save(machine) h.db.Save(machine)
@@ -388,7 +374,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
machine.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry machine.Expiry = &expiry
log.Trace().Msgf("Refreshing machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name) h.setLastStateChangeToNow(machine.Namespace.Name)
h.db.Save(machine) h.db.Save(machine)
@@ -520,14 +505,11 @@ func (machine Machine) toNode(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
if machine.MachineKey != "" { err = machineKey.UnmarshalText(
// MachineKey is only used in the legacy protocol []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
err = machineKey.UnmarshalText( )
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), if err != nil {
) return nil, fmt.Errorf("failed to parse machine public key: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
} }
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
@@ -660,11 +642,11 @@ func (machine *Machine) toProto() *v1.Machine {
} }
func (h *Headscale) RegisterMachineFromAuthCallback( func (h *Headscale) RegisterMachineFromAuthCallback(
nodeKeyStr string, machineKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@@ -695,7 +677,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) { ) (*Machine, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("node_key", machine.NodeKey). Str("machine_key", machine.MachineKey).
Msg("Registering machine") Msg("Registering machine")
log.Trace(). log.Trace().

View File

@@ -10,7 +10,6 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func (s *Suite) TestGetMachine(c *check.C) { func (s *Suite) TestGetMachine(c *check.C) {
@@ -65,35 +64,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) { func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

126
noise.go
View File

@@ -1,126 +0,0 @@
package headscale
import (
"encoding/base64"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/net/netutil"
)
const (
errWrongConnectionUpgrade = Error("wrong connection upgrade")
errCannotHijack = Error("cannot hijack connection")
errNoiseHandshakeFailed = Error("noise handshake failed")
)
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade
ts2021UpgradePath = "/ts2021"
// upgradeHeader is the value of the Upgrade HTTP header used to
// indicate the Tailscale control protocol.
upgradeHeaderValue = "tailscale-control-protocol"
// handshakeHeaderName is the HTTP request header that can
// optionally contain base64-encoded initial handshake
// payload, to save an RTT.
handshakeHeaderName = "X-Tailscale-Handshake"
)
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021
func (h *Headscale) NoiseUpgradeHandler(ctx *gin.Context) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", ctx.ClientIP())
// Under normal circumpstances, we should be able to use the controlhttp.AcceptHTTP()
// function to do this - kindly left there by the Tailscale authors for us to use.
// (https://github.com/tailscale/tailscale/blob/main/control/controlhttp/server.go)
//
// However, Gin seems to be doing something funny/different with its writer (see AcceptHTTP code).
// This causes problems when the upgrade headers are sent in AcceptHTTP.
// So have getNoiseConnection() that is essentially an AcceptHTTP but using the native Gin methods.
noiseConn, err := h.getNoiseConnection(ctx)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
server := http.Server{}
server.Handler = h2c.NewHandler(h.noiseRouter, &http2.Server{})
server.Serve(netutil.NewOneConnListener(noiseConn, nil))
}
// getNoiseConnection is basically AcceptHTTP from tailscale, but more _alla_ Gin
// TODO(juan): Figure out why we need to do this at all.
func (h *Headscale) getNoiseConnection(ctx *gin.Context) (*controlbase.Conn, error) {
next := ctx.GetHeader("Upgrade")
if next == "" {
ctx.String(http.StatusBadRequest, "missing next protocol")
return nil, errWrongConnectionUpgrade
}
if next != upgradeHeaderValue {
ctx.String(http.StatusBadRequest, "unknown next protocol")
return nil, errWrongConnectionUpgrade
}
initB64 := ctx.GetHeader(handshakeHeaderName)
if initB64 == "" {
ctx.String(http.StatusBadRequest, "missing Tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
init, err := base64.StdEncoding.DecodeString(initB64)
if err != nil {
ctx.String(http.StatusBadRequest, "invalid tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
hijacker, ok := ctx.Writer.(http.Hijacker)
if !ok {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
// This is what changes from the original AcceptHTTP() function.
ctx.Header("Upgrade", upgradeHeaderValue)
ctx.Header("Connection", "upgrade")
ctx.Status(http.StatusSwitchingProtocols)
ctx.Writer.WriteHeaderNow()
// end
netConn, conn, err := hijacker.Hijack()
if err != nil {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
if err := conn.Flush(); err != nil {
netConn.Close()
return nil, errCannotHijack
}
netConn = netutil.NewDrainBufConn(netConn, conn.Reader)
nc, err := controlbase.Server(ctx.Request.Context(), netConn, *h.noisePrivateKey, init)
if err != nil {
netConn.Close()
return nil, errNoiseHandshakeFailed
}
return nc, nil
}

View File

@@ -1,551 +0,0 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) {
log.Trace().Caller().Msgf("Noise registration handler for client %s", ctx.ClientIP())
body, _ := io.ReadAll(ctx.Request.Body)
req := tailcfg.RegisterRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
ctx.String(http.StatusInternalServerError, "Eek!")
return
}
log.Info().Caller().
Str("nodekey", req.NodeKey.ShortString()).
Str("oldnodekey", req.OldNodeKey.ShortString()).Msg("Nodekys!")
now := time.Now().UTC()
machine, err := h.GetMachineByNodeKeys(req.NodeKey, req.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine via Noise")
// If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" {
h.handleNoiseAuthKey(ctx, req)
return
}
hname, err := NormalizeToFQDNRules(
req.Hostinfo.Hostname,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Caller().
Str("hostinfo.name", req.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: "",
Name: hname,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !req.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", req.Hostinfo.Hostname).
Time("expiry", req.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry
}
h.registrationCache.Set(
NodePublicKeyStripPrefix(req.NodeKey),
newMachine,
registerCacheExpiration,
)
h.handleMachineRegistrationNew(ctx, key.MachinePublic{}, req)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
h.handleNoiseNodeLogOut(ctx, *machine)
return
}
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleNoiseNodeValidRegistration(ctx, *machine)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) &&
!machine.isExpired() {
h.handleNoiseNodeRefreshKey(ctx, req, *machine)
return
}
// The node has expired
h.handleNoiseNodeExpired(ctx, req, *machine)
return
}
}
// NoisePollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body)
req := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
ctx.String(http.StatusInternalServerError, "Eek!")
return
}
machine, err := h.GetMachineByNodeKeys(req.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().Caller().
Msgf("Ignoring request, cannot find node with node key %s", req.NodeKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Caller().
Msgf("Failed to fetch machine from the database with NodeKey: %s", req.NodeKey.String())
ctx.String(http.StatusInternalServerError, "")
return
}
log.Trace().Caller().
Str("NodeKey", req.NodeKey.ShortString()).
Str("machine", machine.Name).
Msg("Found machine in database")
hname, err := NormalizeToFQDNRules(
req.Hostinfo.Hostname,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Caller().
Str("hostinfo.name", req.Hostinfo.Hostname).
Err(err)
}
machine.Name = hname
machine.HostInfo = HostInfo(*req.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey)
now := time.Now().UTC()
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err = h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", machine.Name).
Err(err)
}
}
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
machine.Endpoints = req.Endpoints
machine.LastSeen = &now
}
h.db.Updates(machine)
data, err := h.getMapResponse(key.MachinePublic{}, req, machine)
if err != nil {
log.Error().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Err(err).
Msg("Failed to get Map response")
ctx.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Noise client map request processed")
if req.ReadOnly {
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client is starting up. Probably interested in a DERP map")
// log.Info().Str("machine", machine.Name).Bytes("resp", data).Msg("Sending DERP map to client")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
log.Trace().Msgf("Updating peers for noise machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Noise loading or creating update channel")
// TODO: could probably remove all that duplication once generics land.
closeChanWithLog := func(channel interface{}, name string) {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))
switch c := channel.(type) {
case (chan struct{}):
close(c)
case (chan []byte):
close(c)
}
}
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
defer closeChanWithLog(updateChan, "updateChan")
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, "pollDataChan")
keepAliveChan := make(chan []byte)
defer closeChanWithLog(keepAliveChan, "keepAliveChan")
if req.OmitPeers && !req.Stream {
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "endpoint-update").
Inc()
updateChan <- struct{}{}
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Caller().
Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it")
ctx.String(http.StatusBadRequest, "")
return
}
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client is ready to access the tailnet")
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Sending initial map")
pollDataChan <- data
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "full-update").
Inc()
updateChan <- struct{}{}
h.PollNetMapStream(
ctx,
machine,
req,
key.MachinePublic{},
pollDataChan,
keepAliveChan,
updateChan,
)
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) handleNoiseNodeValidRegistration(
ctx *gin.Context,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeLogOut(
ctx *gin.Context,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Str("machine", machine.Name).
Msg("Client requested logout")
h.ExpireMachine(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *machine.Namespace.toUser()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeRefreshKey(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Debug().
Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.db.Save(&machine)
resp.AuthURL = ""
resp.User = *machine.Namespace.toUser()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeExpired(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Caller().
Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleNoiseAuthKey(ctx, registerRequest)
return
}
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseAuthKey(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
) {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s over Noise", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
ctx.JSON(http.StatusUnauthorized, resp)
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey over Noise")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKeys(registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
h.RefreshMachine(machine, registerRequest.Expiry)
} else {
now := time.Now().UTC()
machineToRegister := Machine{
Name: registerRequest.Hostinfo.Hostname,
NamespaceID: pak.Namespace.ID,
MachineKey: "",
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
ctx.String(
http.StatusInternalServerError,
"could not register machine",
)
return
}
}
h.UsePreAuthKey(pak)
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey on Noise")
}

36
oidc.go
View File

@@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey. // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
nodeKeyStr := ctx.Param("nkey") machineKeyStr := ctx.Param("mkey")
if nodeKeyStr == "" { if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
@@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("node_key", nodeKeyStr). Str("machine_key", machineKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
@@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@@ -114,7 +114,7 @@ var oidcCallbackTemplate = template.Must(
) )
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the machine to the users email namespace // Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
@@ -188,32 +188,32 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
// retrieve machinekey from state cache // retrieve machinekey from state cache
nodeKeyIf, machineKeyFound := h.registrationCache.Get(state) machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
Msg("requested node state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
return return
} }
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
var nodeKey key.NodePublic var machineKey key.MachinePublic
err = nodeKey.UnmarshalText( err = machineKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), []byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse node public key") Msg("could not parse machine public key")
ctx.String(http.StatusBadRequest, "could not parse public key") ctx.String(http.StatusBadRequest, "could not parse public key")
return return
} }
if !nodeKeyOK { if !machineKeyOK {
log.Error().Msg("could not get node key from cache") log.Error().Msg("could not get machine key from cache")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine key from cache", "could not get machine key from cache",
@@ -226,7 +226,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{}) machine, _ := h.GetMachineByMachineKey(machineKey)
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
@@ -305,10 +305,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
nodeKeyStr := NodePublicKeyStripPrefix(nodeKey) machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
_, err = h.RegisterMachineFromAuthCallback( _, err = h.RegisterMachineFromAuthCallback(
nodeKeyStr, machineKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
) )

36
poll.go
View File

@@ -64,8 +64,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Caller(). Str("handler", "PollNetMap").
Msgf("Ignoring request (client %s), cannot find machine with key %s", ctx.ClientIP(), machineKey.String()) Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "") ctx.String(http.StatusUnauthorized, "")
return return
@@ -163,7 +163,6 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// There has been an update to _any_ of the nodes that the other nodes would // There has been an update to _any_ of the nodes that the other nodes would
// need to know about // need to know about
log.Trace().Msgf("Updating peers for machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name) h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating // The request is not ReadOnly, so we need to set up channels for updating
@@ -242,7 +241,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
// PollNetMapStream takes care of /map // PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data // stream logic, ensuring we communicate updates and data
// to the connected clients. // to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
@@ -255,6 +254,24 @@ func (h *Headscale) PollNetMapStream(
updateChan chan struct{}, updateChan chan struct{},
) { ) {
{ {
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
return
}
ctx := context.WithValue(ctx.Request.Context(), "machineName", machine.Name) ctx := context.WithValue(ctx.Request.Context(), "machineName", machine.Name)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@@ -372,7 +389,10 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalbCne(machine) // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@@ -434,7 +454,7 @@ func (h *Headscale) PollNetMapStream(
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
} }
nBytes, err := writer.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@@ -451,7 +471,7 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Msgf("Updated Map has been sent (%d bytes)", nBytes) Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Name, "success"). updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Name, "success").
Inc() Inc()
@@ -589,7 +609,7 @@ func (h *Headscale) scheduledPollWorker(
case <-updateCheckerTicker.C: case <-updateCheckerTicker.C:
log.Debug(). log.Debug().
Caller(). Str("func", "scheduledPollWorker").
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Sending update request") Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "scheduled-update"). updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "scheduled-update").