Compare commits

..

3 Commits

Author SHA1 Message Date
Juan Font
0d98493360 Reduced the number of containers in integration tests 2022-04-30 21:14:56 +00:00
Juan Font
03659c4175 Updated changelog 2022-04-30 14:50:55 +00:00
Juan Font
843e2bd9b6 Do not setLastStateChangeToNow every 5 seconds 2022-04-30 14:47:16 +00:00
16 changed files with 132 additions and 956 deletions

View File

@@ -7,6 +7,7 @@
- Headscale fails to serve if the ACL policy file cannot be parsed [#537](https://github.com/juanfont/headscale/pull/537)
- Fix labels cardinality error when registering unknown pre-auth key [#519](https://github.com/juanfont/headscale/pull/519)
- Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542)
- Fixed spurious calls to setLastStateChangeToNow from ephemeral nodes [#566](https://github.com/juanfont/headscale/pull/566)
## 0.15.0 (2022-03-20)

175
api.go
View File

@@ -9,7 +9,6 @@ import (
"html/template"
"io"
"net/http"
"strconv"
"strings"
"time"
@@ -22,50 +21,18 @@ import (
)
const (
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire",
)
)
const (
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28.
// See https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 28
)
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
v := ctx.Query("v")
if v != "" {
clientCapabilityVersion, err := strconv.Atoi(v)
if err != nil {
ctx.String(http.StatusBadRequest, "Invalid version")
return
}
if clientCapabilityVersion >= NoiseCapabilityVersion {
// Tailscale has a different key for the TS2021 protocol. Not sure why.
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
ctx.JSON(http.StatusOK, resp)
return
}
}
// Old clients don't send a 'v' parameter, so we send the legacy public key
ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
@@ -202,7 +169,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
}
h.registrationCache.Set(
NodePublicKeyStripPrefix(req.NodeKey),
machineKeyStr,
newMachine,
registerCacheExpiration,
)
@@ -321,61 +288,33 @@ func (h *Headscale) getMapResponse(
Msgf("Generated map response: %s", tailMapResponseToString(resp))
var respBody []byte
if machineKey.IsZero() {
// The TS2021 protocol does not rely anymore on the machine key to
// encrypt in a NaCl box the map response. We just send it back
// unencrypted via the encrypted Noise channel.
// declare the incoming size on the first 4 bytes
respBody, err := json.Marshal(resp)
if req.Compress == "zstd" {
src, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot marshal map response")
Msg("Failed to marshal response for the client")
return nil, err
}
var srcCompressed []byte
if req.Compress == "zstd" {
encoder, _ := zstd.NewWriter(nil)
srcCompressed = encoder.EncodeAll(respBody, nil)
} else {
srcCompressed = respBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(srcCompressed)))
data = append(data, srcCompressed...)
return data, nil
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
if req.Compress == "zstd" {
src, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Failed to marshal response for the client")
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
// declare the incoming size on the first 4 bytes
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
// declare the incoming size on the first 4 bytes
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) getMapKeepAliveResponse(
@@ -387,36 +326,31 @@ func (h *Headscale) getMapKeepAliveResponse(
}
var respBody []byte
var err error
if machineKey.IsZero() {
// The TS2021 protocol does not rely anymore on the machine key.
return json.Marshal(mapResponse)
} else {
if mapRequest.Compress == "zstd" {
src, err := json.Marshal(mapResponse)
if err != nil {
log.Error().
Caller().
Str("func", "getMapKeepAliveResponse").
Err(err).
Msg("Failed to marshal keepalive response for the client")
if mapRequest.Compress == "zstd" {
src, err := json.Marshal(mapResponse)
if err != nil {
log.Error().
Caller().
Str("func", "getMapKeepAliveResponse").
Err(err).
Msg("Failed to marshal keepalive response for the client")
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) handleMachineLogOut(
@@ -477,7 +411,6 @@ func (h *Headscale) handleMachineValidRegistration(
return
}
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@@ -504,10 +437,10 @@ func (h *Headscale) handleMachineExpired(
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
}
respBody, err := encode(resp, &machineKey, h.privateKey)
@@ -522,7 +455,6 @@ func (h *Headscale) handleMachineExpired(
return
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@@ -572,21 +504,13 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey),
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
}
if machineKey.IsZero() {
// TS2021
ctx.JSON(http.StatusOK, resp)
return
}
// The Tailscale legacy protocol requires to encrypt the NaCl box with the MachineKey
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
@@ -600,6 +524,7 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey(
ctx *gin.Context,
machineKey key.MachinePublic,

69
app.go
View File

@@ -81,7 +81,6 @@ type Config struct {
EphemeralNodeInactivityTimeout time.Duration
IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string
DERP DERPConfig
@@ -144,15 +143,12 @@ type CLIConfig struct {
// Headscale represents the base app of the service.
type Headscale struct {
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate
noiseRouter *gin.Engine
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
DERPMap *tailcfg.DERPMap
DERPServer *DERPServer
@@ -192,20 +188,11 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
}
func NewHeadscale(cfg Config) (*Headscale, error) {
privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err)
}
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create noise private key: %w", err)
}
if privateKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("private key and noise private key are the same")
}
var dbString string
switch cfg.DBtype {
case Postgres:
@@ -232,8 +219,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
privateKey: privateKey,
noisePrivateKey: noisePrivateKey,
privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache,
}
@@ -273,10 +259,9 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
}
// Redirect to our TLS url.
func (h *Headscale) redirect(ctx *gin.Context) {
log.Trace().Msgf("Redirecting to TLS, path %s", ctx.Request.RequestURI)
target := h.cfg.ServerURL + ctx.Request.RequestURI
http.Redirect(ctx.Writer, ctx.Request, target, http.StatusFound)
func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
target := h.cfg.ServerURL + req.URL.RequestURI()
http.Redirect(w, req, target, http.StatusFound)
}
// expireEphemeralNodes deletes ephemeral machine records that have not been
@@ -307,11 +292,13 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return
}
expiredFound := false
for _, machine := range machines {
if machine.AuthKey != nil && machine.LastSeen != nil &&
machine.AuthKey.Ephemeral &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
expiredFound = true
log.Info().
Str("machine", machine.Name).
Msg("Ephemeral client removed from database")
@@ -326,7 +313,9 @@ func (h *Headscale) expireEphemeralNodesWorker() {
}
}
h.setLastStateChangeToNow(namespace.Name)
if expiredFound {
h.setLastStateChangeToNow(namespace.Name)
}
}
}
@@ -479,13 +468,11 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
router.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:nkey", h.RegisterOIDC)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig)
@@ -511,15 +498,6 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
return router
}
func (h *Headscale) createNoiseRouter() *gin.Engine {
router := gin.Default()
router.POST("/machine/register", h.NoiseRegistrationHandler)
router.POST("/machine/map", h.NoisePollNetMapHandler)
return router
}
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
var err error
@@ -685,14 +663,8 @@ func (h *Headscale) Serve() error {
// HTTP setup
//
// This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailcale API
router := h.createRouter(grpcGatewayMux)
// This router is only served over the Noise connection,
// and exposes only the new API
h.noiseRouter = h.createNoiseRouter()
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: router,
@@ -773,14 +745,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
httpRouter := gin.Default()
httpRouter.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
httpRouter.NoRoute(h.redirect)
go func() {
log.Fatal().
Caller().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(httpRouter))).
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
@@ -818,7 +786,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
}
func (h *Headscale) setLastStateChangeToNow(namespace string) {
log.Trace().Msgf("setting last state change to now for namespace %s", namespace)
now := time.Now().UTC()
lastStateUpdate.WithLabelValues("", "headscale").Set(float64(now.Unix()))
h.lastStateChange.Store(namespace, now)

View File

@@ -326,10 +326,9 @@ func getHeadscaleConfig() headscale.Config {
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
NoisePrivateKeyPath: absPath(viper.GetString("noise_private_key_path")),
BaseDomain: baseDomain,
IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain,
DERP: derpConfig,

View File

@@ -41,13 +41,6 @@ grpc_allow_insecure: false
# autogenerated if it's missing
private_key_path: /var/lib/headscale/private.key
# The Noise private key is used to encrypt the
# traffic between headscale and Tailscale clients when
# using the new Noise-based TS2021 protocol.
# The noise private key file which will be
# autogenerated if it's missing
noise_private_key_path: /var/lib/headscale/noise_private.key
# List of IP prefixes to allocate tailaddresses from.
# Each prefix consists of either an IPv4 or IPv6 address,
# and the associated prefix length, delimited by a slash.

2
go.mod
View File

@@ -27,7 +27,6 @@ require (
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
github.com/zsais/go-gin-prometheus v0.1.0
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/net v0.0.0-20220412020605-290c469a71a5
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
google.golang.org/genproto v0.0.0-20220422154200-b37d22cd5731
@@ -133,6 +132,7 @@ require (
go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect
golang.org/x/net v0.0.0-20220412020605-290c469a71a5 // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect

View File

@@ -5,10 +5,9 @@ import (
"context"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
@@ -374,7 +373,6 @@ func (api headscaleV1APIServer) DebugCreateMachine(
MachineKey: request.GetKey(),
Name: request.GetName(),
Namespace: *namespace,
NodeKey: key.NewNode().Public().String(),
Expiry: &time.Time{},
LastSeen: &time.Time{},
@@ -384,7 +382,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
}
api.h.registrationCache.Set(
newMachine.NodeKey,
request.GetKey(),
newMachine,
registerCacheExpiration,
)

View File

@@ -47,11 +47,11 @@ func TestIntegrationTestSuite(t *testing.T) {
s.namespaces = map[string]TestNamespace{
"thisspace": {
count: 15,
count: 10,
tailscales: make(map[string]dockertest.Resource),
},
"otherspace": {
count: 5,
count: 3,
tailscales: make(map[string]dockertest.Resource),
},
}

View File

@@ -13,7 +13,6 @@ dns_config:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:8080
metrics_listen_addr: 127.0.0.1:9090
server_url: http://headscale:8080

View File

@@ -13,7 +13,6 @@ dns_config:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:8443
server_url: https://headscale:8443
tls_cert_path: "/etc/headscale/tls/server.crt"

View File

@@ -335,7 +335,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
@@ -347,19 +347,6 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil
}
// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByNodeKeys(
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(machine *Machine) error {
@@ -375,7 +362,6 @@ func (h *Headscale) ExpireMachine(machine *Machine) {
now := time.Now()
machine.Expiry = &now
log.Trace().Msgf("Expiring machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
h.db.Save(machine)
@@ -388,7 +374,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
log.Trace().Msgf("Refreshing machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
h.db.Save(machine)
@@ -520,14 +505,11 @@ func (machine Machine) toNode(
}
var machineKey key.MachinePublic
if machine.MachineKey != "" {
// MachineKey is only used in the legacy protocol
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
var discoKey key.DiscoPublic
@@ -660,11 +642,11 @@ func (machine *Machine) toProto() *v1.Machine {
}
func (h *Headscale) RegisterMachineFromAuthCallback(
nodeKeyStr string,
machineKeyStr string,
namespaceName string,
registrationMethod string,
) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
@@ -695,7 +677,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) {
log.Trace().
Caller().
Str("node_key", machine.NodeKey).
Str("machine_key", machine.MachineKey).
Msg("Registering machine")
log.Trace().

View File

@@ -10,7 +10,6 @@ import (
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetMachine(c *check.C) {
@@ -65,35 +64,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)

126
noise.go
View File

@@ -1,126 +0,0 @@
package headscale
import (
"encoding/base64"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/net/netutil"
)
const (
errWrongConnectionUpgrade = Error("wrong connection upgrade")
errCannotHijack = Error("cannot hijack connection")
errNoiseHandshakeFailed = Error("noise handshake failed")
)
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade
ts2021UpgradePath = "/ts2021"
// upgradeHeader is the value of the Upgrade HTTP header used to
// indicate the Tailscale control protocol.
upgradeHeaderValue = "tailscale-control-protocol"
// handshakeHeaderName is the HTTP request header that can
// optionally contain base64-encoded initial handshake
// payload, to save an RTT.
handshakeHeaderName = "X-Tailscale-Handshake"
)
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021
func (h *Headscale) NoiseUpgradeHandler(ctx *gin.Context) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", ctx.ClientIP())
// Under normal circumpstances, we should be able to use the controlhttp.AcceptHTTP()
// function to do this - kindly left there by the Tailscale authors for us to use.
// (https://github.com/tailscale/tailscale/blob/main/control/controlhttp/server.go)
//
// However, Gin seems to be doing something funny/different with its writer (see AcceptHTTP code).
// This causes problems when the upgrade headers are sent in AcceptHTTP.
// So have getNoiseConnection() that is essentially an AcceptHTTP but using the native Gin methods.
noiseConn, err := h.getNoiseConnection(ctx)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
server := http.Server{}
server.Handler = h2c.NewHandler(h.noiseRouter, &http2.Server{})
server.Serve(netutil.NewOneConnListener(noiseConn, nil))
}
// getNoiseConnection is basically AcceptHTTP from tailscale, but more _alla_ Gin
// TODO(juan): Figure out why we need to do this at all.
func (h *Headscale) getNoiseConnection(ctx *gin.Context) (*controlbase.Conn, error) {
next := ctx.GetHeader("Upgrade")
if next == "" {
ctx.String(http.StatusBadRequest, "missing next protocol")
return nil, errWrongConnectionUpgrade
}
if next != upgradeHeaderValue {
ctx.String(http.StatusBadRequest, "unknown next protocol")
return nil, errWrongConnectionUpgrade
}
initB64 := ctx.GetHeader(handshakeHeaderName)
if initB64 == "" {
ctx.String(http.StatusBadRequest, "missing Tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
init, err := base64.StdEncoding.DecodeString(initB64)
if err != nil {
ctx.String(http.StatusBadRequest, "invalid tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
hijacker, ok := ctx.Writer.(http.Hijacker)
if !ok {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
// This is what changes from the original AcceptHTTP() function.
ctx.Header("Upgrade", upgradeHeaderValue)
ctx.Header("Connection", "upgrade")
ctx.Status(http.StatusSwitchingProtocols)
ctx.Writer.WriteHeaderNow()
// end
netConn, conn, err := hijacker.Hijack()
if err != nil {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
if err := conn.Flush(); err != nil {
netConn.Close()
return nil, errCannotHijack
}
netConn = netutil.NewDrainBufConn(netConn, conn.Reader)
nc, err := controlbase.Server(ctx.Request.Context(), netConn, *h.noisePrivateKey, init)
if err != nil {
netConn.Close()
return nil, errNoiseHandshakeFailed
}
return nc, nil
}

View File

@@ -1,551 +0,0 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) {
log.Trace().Caller().Msgf("Noise registration handler for client %s", ctx.ClientIP())
body, _ := io.ReadAll(ctx.Request.Body)
req := tailcfg.RegisterRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
ctx.String(http.StatusInternalServerError, "Eek!")
return
}
log.Info().Caller().
Str("nodekey", req.NodeKey.ShortString()).
Str("oldnodekey", req.OldNodeKey.ShortString()).Msg("Nodekys!")
now := time.Now().UTC()
machine, err := h.GetMachineByNodeKeys(req.NodeKey, req.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine via Noise")
// If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" {
h.handleNoiseAuthKey(ctx, req)
return
}
hname, err := NormalizeToFQDNRules(
req.Hostinfo.Hostname,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Caller().
Str("hostinfo.name", req.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: "",
Name: hname,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !req.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", req.Hostinfo.Hostname).
Time("expiry", req.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry
}
h.registrationCache.Set(
NodePublicKeyStripPrefix(req.NodeKey),
newMachine,
registerCacheExpiration,
)
h.handleMachineRegistrationNew(ctx, key.MachinePublic{}, req)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
h.handleNoiseNodeLogOut(ctx, *machine)
return
}
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleNoiseNodeValidRegistration(ctx, *machine)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) &&
!machine.isExpired() {
h.handleNoiseNodeRefreshKey(ctx, req, *machine)
return
}
// The node has expired
h.handleNoiseNodeExpired(ctx, req, *machine)
return
}
}
// NoisePollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body)
req := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
ctx.String(http.StatusInternalServerError, "Eek!")
return
}
machine, err := h.GetMachineByNodeKeys(req.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().Caller().
Msgf("Ignoring request, cannot find node with node key %s", req.NodeKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Caller().
Msgf("Failed to fetch machine from the database with NodeKey: %s", req.NodeKey.String())
ctx.String(http.StatusInternalServerError, "")
return
}
log.Trace().Caller().
Str("NodeKey", req.NodeKey.ShortString()).
Str("machine", machine.Name).
Msg("Found machine in database")
hname, err := NormalizeToFQDNRules(
req.Hostinfo.Hostname,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Caller().
Str("hostinfo.name", req.Hostinfo.Hostname).
Err(err)
}
machine.Name = hname
machine.HostInfo = HostInfo(*req.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey)
now := time.Now().UTC()
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err = h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", machine.Name).
Err(err)
}
}
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
machine.Endpoints = req.Endpoints
machine.LastSeen = &now
}
h.db.Updates(machine)
data, err := h.getMapResponse(key.MachinePublic{}, req, machine)
if err != nil {
log.Error().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Err(err).
Msg("Failed to get Map response")
ctx.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Noise client map request processed")
if req.ReadOnly {
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client is starting up. Probably interested in a DERP map")
// log.Info().Str("machine", machine.Name).Bytes("resp", data).Msg("Sending DERP map to client")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
log.Trace().Msgf("Updating peers for noise machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Noise loading or creating update channel")
// TODO: could probably remove all that duplication once generics land.
closeChanWithLog := func(channel interface{}, name string) {
log.Trace().
Caller().
Str("machine", machine.Name).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))
switch c := channel.(type) {
case (chan struct{}):
close(c)
case (chan []byte):
close(c)
}
}
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
defer closeChanWithLog(updateChan, "updateChan")
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, "pollDataChan")
keepAliveChan := make(chan []byte)
defer closeChanWithLog(keepAliveChan, "keepAliveChan")
if req.OmitPeers && !req.Stream {
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "endpoint-update").
Inc()
updateChan <- struct{}{}
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Caller().
Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it")
ctx.String(http.StatusBadRequest, "")
return
}
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Noise client is ready to access the tailnet")
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Sending initial map")
pollDataChan <- data
log.Info().
Caller().
Str("machine", machine.Name).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "full-update").
Inc()
updateChan <- struct{}{}
h.PollNetMapStream(
ctx,
machine,
req,
key.MachinePublic{},
pollDataChan,
keepAliveChan,
updateChan,
)
log.Trace().
Caller().
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) handleNoiseNodeValidRegistration(
ctx *gin.Context,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeLogOut(
ctx *gin.Context,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Str("machine", machine.Name).
Msg("Client requested logout")
h.ExpireMachine(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *machine.Namespace.toUser()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeRefreshKey(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Debug().
Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.db.Save(&machine)
resp.AuthURL = ""
resp.User = *machine.Namespace.toUser()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseNodeExpired(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Caller().
Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleNoiseAuthKey(ctx, registerRequest)
return
}
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
}
func (h *Headscale) handleNoiseAuthKey(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
) {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s over Noise", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
ctx.JSON(http.StatusUnauthorized, resp)
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey over Noise")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKeys(registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
h.RefreshMachine(machine, registerRequest.Expiry)
} else {
now := time.Now().UTC()
machineToRegister := Machine{
Name: registerRequest.Hostinfo.Hostname,
NamespaceID: pak.Namespace.ID,
MachineKey: "",
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
ctx.String(
http.StatusInternalServerError,
"could not register machine",
)
return
}
}
h.UsePreAuthKey(pak)
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey on Noise")
}

36
oidc.go
View File

@@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
nodeKeyStr := ctx.Param("nkey")
if nodeKeyStr == "" {
machineKeyStr := ctx.Param("mkey")
if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params")
return
@@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace().
Caller().
Str("node_key", nodeKeyStr).
Str("machine_key", machineKeyStr).
Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize)
@@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@@ -114,7 +114,7 @@ var oidcCallbackTemplate = template.Must(
)
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the machine to the users email namespace
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
@@ -188,32 +188,32 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
}
// retrieve machinekey from state cache
nodeKeyIf, machineKeyFound := h.registrationCache.Get(state)
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Error().
Msg("requested node state key expired before authorisation completed")
Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired")
return
}
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
var nodeKey key.NodePublic
err = nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
var machineKey key.MachinePublic
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
)
if err != nil {
log.Error().
Msg("could not parse node public key")
Msg("could not parse machine public key")
ctx.String(http.StatusBadRequest, "could not parse public key")
return
}
if !nodeKeyOK {
log.Error().Msg("could not get node key from cache")
if !machineKeyOK {
log.Error().Msg("could not get machine key from cache")
ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
@@ -226,7 +226,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
machine, _ := h.GetMachineByMachineKey(machineKey)
if machine != nil {
log.Trace().
@@ -305,10 +305,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
_, err = h.RegisterMachineFromAuthCallback(
nodeKeyStr,
machineKeyStr,
namespace.Name,
RegisterMethodOIDC,
)

36
poll.go
View File

@@ -64,8 +64,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Caller().
Msgf("Ignoring request (client %s), cannot find machine with key %s", ctx.ClientIP(), machineKey.String())
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return
@@ -163,7 +163,6 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
log.Trace().Msgf("Updating peers for machine %s", machine.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
@@ -242,7 +241,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /map
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
@@ -255,6 +254,24 @@ func (h *Headscale) PollNetMapStream(
updateChan chan struct{},
) {
{
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
return
}
ctx := context.WithValue(ctx.Request.Context(), "machineName", machine.Name)
ctx, cancel := context.WithCancel(ctx)
@@ -372,7 +389,10 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalbCne(machine)
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@@ -434,7 +454,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Could not get the map update")
}
nBytes, err := writer.Write(data)
_, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@@ -451,7 +471,7 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream").
Str("machine", machine.Name).
Str("channel", "update").
Msgf("Updated Map has been sent (%d bytes)", nBytes)
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Name, "success").
Inc()
@@ -589,7 +609,7 @@ func (h *Headscale) scheduledPollWorker(
case <-updateCheckerTicker.C:
log.Debug().
Caller().
Str("func", "scheduledPollWorker").
Str("machine", machine.Name).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "scheduled-update").