mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-18 15:34:06 +01:00
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
1145 lines
29 KiB
Go
1145 lines
29 KiB
Go
package hscontrol
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
_ "net/http/pprof" // nolint
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v5"
|
|
"github.com/davecgh/go-spew/spew"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/go-chi/metrics"
|
|
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
|
"github.com/juanfont/headscale"
|
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
|
"github.com/juanfont/headscale/hscontrol/capver"
|
|
"github.com/juanfont/headscale/hscontrol/db"
|
|
"github.com/juanfont/headscale/hscontrol/derp"
|
|
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
|
"github.com/juanfont/headscale/hscontrol/dns"
|
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
|
"github.com/juanfont/headscale/hscontrol/state"
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
|
"github.com/pkg/profile"
|
|
zl "github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/sasha-s/go-deadlock"
|
|
"golang.org/x/crypto/acme"
|
|
"golang.org/x/crypto/acme/autocert"
|
|
"golang.org/x/sync/errgroup"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/reflection"
|
|
"google.golang.org/grpc/status"
|
|
"tailscale.com/envknob"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/dnstype"
|
|
"tailscale.com/types/key"
|
|
"tailscale.com/util/dnsname"
|
|
)
|
|
|
|
// Sentinel errors returned while constructing or starting the server.
var (
	// errSTUNAddressNotSet is returned by Serve when the embedded DERP server
	// is enabled but no STUN address is configured.
	errSTUNAddressNotSet = errors.New("STUN address not set")

	// errUnsupportedLetsEncryptChallengeType is returned by getTLSSettings
	// when the configured ACME challenge type is not one it handles.
	errUnsupportedLetsEncryptChallengeType = errors.New("unknown value for Lets Encrypt challenge type")

	// errEmptyInitialDERPMap is returned by Serve when the DERP map built at
	// startup contains no regions.
	errEmptyInitialDERPMap = errors.New("initial DERPMap is empty, Headscale requires at least one entry")
)
|
|
|
|
var (
|
|
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
|
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
|
)
|
|
|
|
func init() {
|
|
deadlock.Opts.Disable = !debugDeadlock
|
|
if debugDeadlock {
|
|
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
|
deadlock.Opts.PrintAllCurrentGoroutines = true
|
|
}
|
|
}
|
|
|
|
const (
	// AuthPrefix is the scheme prefix expected at the start of the
	// "Authorization" header/metadata value for API requests.
	AuthPrefix = "Bearer "

	// updateInterval is the period of the node-expiry ticker in scheduledTasks.
	updateInterval = 5 * time.Second

	// privateKeyFileMode is the file mode used when writing a newly created
	// private key to disk.
	privateKeyFileMode = 0o600

	// headscaleDirPerm is a directory permission constant; it is not
	// referenced in the visible part of this file.
	headscaleDirPerm = 0o700
)
|
|
|
|
// Headscale represents the base app of the service.
|
|
type Headscale struct {
|
|
cfg *types.Config
|
|
state *state.State
|
|
noisePrivateKey *key.MachinePrivate
|
|
ephemeralGC *db.EphemeralGarbageCollector
|
|
|
|
DERPServer *derpServer.DERPServer
|
|
|
|
// Things that generate changes
|
|
extraRecordMan *dns.ExtraRecordsMan
|
|
authProvider AuthProvider
|
|
mapBatcher mapper.Batcher
|
|
|
|
clientStreamsOpen sync.WaitGroup
|
|
}
|
|
|
|
var (
|
|
profilingEnabled = envknob.Bool("HEADSCALE_DEBUG_PROFILING_ENABLED")
|
|
profilingPath = envknob.String("HEADSCALE_DEBUG_PROFILING_PATH")
|
|
tailsqlEnabled = envknob.Bool("HEADSCALE_DEBUG_TAILSQL_ENABLED")
|
|
tailsqlStateDir = envknob.String("HEADSCALE_DEBUG_TAILSQL_STATE_DIR")
|
|
tailsqlTSKey = envknob.String("TS_AUTHKEY")
|
|
dumpConfig = envknob.Bool("HEADSCALE_DEBUG_DUMP_CONFIG")
|
|
)
|
|
|
|
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|
var err error
|
|
|
|
if profilingEnabled {
|
|
runtime.SetBlockProfileRate(1)
|
|
}
|
|
|
|
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading or creating Noise protocol private key: %w", err)
|
|
}
|
|
|
|
s, err := state.NewState(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init state: %w", err)
|
|
}
|
|
|
|
app := Headscale{
|
|
cfg: cfg,
|
|
noisePrivateKey: noisePrivateKey,
|
|
clientStreamsOpen: sync.WaitGroup{},
|
|
state: s,
|
|
}
|
|
|
|
// Initialize ephemeral garbage collector
|
|
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
|
node, ok := app.state.GetNodeByID(ni)
|
|
if !ok {
|
|
log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed")
|
|
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore")
|
|
|
|
return
|
|
}
|
|
|
|
policyChanged, err := app.state.DeleteNode(node)
|
|
if err != nil {
|
|
log.Error().Err(err).EmbedObject(node).Msg("ephemeral node deletion failed")
|
|
return
|
|
}
|
|
|
|
app.Change(policyChanged)
|
|
log.Debug().Caller().EmbedObject(node).Msg("ephemeral node deleted because garbage collection timeout reached")
|
|
})
|
|
app.ephemeralGC = ephemeralGC
|
|
|
|
var authProvider AuthProvider
|
|
|
|
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
|
if cfg.OIDC.Issuer != "" {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
oidcProvider, err := NewAuthProviderOIDC(
|
|
ctx,
|
|
&app,
|
|
cfg.ServerURL,
|
|
&cfg.OIDC,
|
|
)
|
|
if err != nil {
|
|
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
|
return nil, err
|
|
} else {
|
|
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
|
|
}
|
|
} else {
|
|
authProvider = oidcProvider
|
|
}
|
|
}
|
|
|
|
app.authProvider = authProvider
|
|
|
|
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
|
// TODO(kradalby): revisit why this takes a list.
|
|
var magicDNSDomains []dnsname.FQDN
|
|
if cfg.PrefixV4 != nil {
|
|
magicDNSDomains = append(
|
|
magicDNSDomains,
|
|
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
|
}
|
|
|
|
if cfg.PrefixV6 != nil {
|
|
magicDNSDomains = append(
|
|
magicDNSDomains,
|
|
util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
|
|
}
|
|
|
|
// we might have routes already from Split DNS
|
|
if app.cfg.TailcfgDNSConfig.Routes == nil {
|
|
app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
|
}
|
|
|
|
for _, d := range magicDNSDomains {
|
|
app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
|
}
|
|
}
|
|
|
|
if cfg.DERP.ServerEnabled {
|
|
derpServerKey, err := readOrCreatePrivateKey(cfg.DERP.ServerPrivateKeyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading or creating DERP server private key: %w", err)
|
|
}
|
|
|
|
if derpServerKey.Equal(*noisePrivateKey) {
|
|
return nil, fmt.Errorf(
|
|
"DERP server private key and noise private key are the same: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
if cfg.DERP.ServerVerifyClients {
|
|
t := http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert
|
|
t.RegisterProtocol(
|
|
derpServer.DerpVerifyScheme,
|
|
derpServer.NewDERPVerifyTransport(app.handleVerifyRequest),
|
|
)
|
|
}
|
|
|
|
embeddedDERPServer, err := derpServer.NewDERPServer(
|
|
cfg.ServerURL,
|
|
key.NodePrivate(*derpServerKey),
|
|
&cfg.DERP,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
app.DERPServer = embeddedDERPServer
|
|
}
|
|
|
|
return &app, nil
|
|
}
|
|
|
|
// Redirect to our TLS url.
|
|
func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
|
target := h.cfg.ServerURL + req.URL.RequestURI()
|
|
http.Redirect(w, req, target, http.StatusFound)
|
|
}
|
|
|
|
func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|
expireTicker := time.NewTicker(updateInterval)
|
|
defer expireTicker.Stop()
|
|
|
|
lastExpiryCheck := time.Unix(0, 0)
|
|
|
|
derpTickerChan := make(<-chan time.Time)
|
|
|
|
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
|
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
|
defer derpTicker.Stop()
|
|
|
|
derpTickerChan = derpTicker.C
|
|
}
|
|
|
|
var extraRecordsUpdate <-chan []tailcfg.DNSRecord
|
|
if h.extraRecordMan != nil {
|
|
extraRecordsUpdate = h.extraRecordMan.UpdateCh()
|
|
} else {
|
|
extraRecordsUpdate = make(chan []tailcfg.DNSRecord)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Info().Caller().Msg("scheduled task worker is shutting down.")
|
|
return
|
|
|
|
case <-expireTicker.C:
|
|
var (
|
|
expiredNodeChanges []change.Change
|
|
changed bool
|
|
)
|
|
|
|
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
|
|
|
if changed {
|
|
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
|
|
|
|
// Send the changes directly since they're already in the new format
|
|
for _, nodeChange := range expiredNodeChanges {
|
|
h.Change(nodeChange)
|
|
}
|
|
}
|
|
|
|
case <-derpTickerChan:
|
|
log.Info().Msg("fetching DERPMap updates")
|
|
|
|
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck
|
|
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
|
region, _ := h.DERPServer.GenerateRegion()
|
|
derpMap.Regions[region.RegionID] = ®ion
|
|
}
|
|
|
|
return derpMap, nil
|
|
}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
|
continue
|
|
}
|
|
|
|
h.state.SetDERPMap(derpMap)
|
|
|
|
h.Change(change.DERPMap())
|
|
|
|
case records, ok := <-extraRecordsUpdate:
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
|
|
|
h.Change(change.ExtraRecords())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|
req any,
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (any, error) {
|
|
// Check if the request is coming from the on-server client.
|
|
// This is not secure, but it is to maintain maintainability
|
|
// with the "legacy" database-based client
|
|
// It is also needed for grpc-gateway to be able to connect to
|
|
// the server
|
|
client, _ := peer.FromContext(ctx)
|
|
|
|
log.Trace().
|
|
Caller().
|
|
Str("client_address", client.Addr.String()).
|
|
Msg("Client is trying to authenticate")
|
|
|
|
meta, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return ctx, status.Errorf(
|
|
codes.InvalidArgument,
|
|
"retrieving metadata",
|
|
)
|
|
}
|
|
|
|
authHeader, ok := meta["authorization"]
|
|
if !ok {
|
|
return ctx, status.Errorf(
|
|
codes.Unauthenticated,
|
|
"authorization token not supplied",
|
|
)
|
|
}
|
|
|
|
token := authHeader[0]
|
|
|
|
if !strings.HasPrefix(token, AuthPrefix) {
|
|
return ctx, status.Error(
|
|
codes.Unauthenticated,
|
|
`missing "Bearer " prefix in "Authorization" header`,
|
|
)
|
|
}
|
|
|
|
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
|
if err != nil {
|
|
return ctx, status.Error(codes.Internal, "validating token")
|
|
}
|
|
|
|
if !valid {
|
|
log.Info().
|
|
Str("client_address", client.Addr.String()).
|
|
Msg("invalid token")
|
|
|
|
return ctx, status.Error(codes.Unauthenticated, "invalid token")
|
|
}
|
|
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(
|
|
writer http.ResponseWriter,
|
|
req *http.Request,
|
|
) {
|
|
log.Trace().
|
|
Caller().
|
|
Str("client_address", req.RemoteAddr).
|
|
Msg("HTTP authentication invoked")
|
|
|
|
authHeader := req.Header.Get("Authorization")
|
|
|
|
writeUnauthorized := func(statusCode int) {
|
|
writer.WriteHeader(statusCode)
|
|
|
|
if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr
|
|
log.Error().Err(err).Msg("writing HTTP response failed")
|
|
}
|
|
}
|
|
|
|
if !strings.HasPrefix(authHeader, AuthPrefix) {
|
|
log.Error().
|
|
Caller().
|
|
Str("client_address", req.RemoteAddr).
|
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
|
writeUnauthorized(http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
|
|
if err != nil {
|
|
log.Info().
|
|
Caller().
|
|
Err(err).
|
|
Str("client_address", req.RemoteAddr).
|
|
Msg("failed to validate token")
|
|
writeUnauthorized(http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
if !valid {
|
|
log.Info().
|
|
Str("client_address", req.RemoteAddr).
|
|
Msg("invalid token")
|
|
writeUnauthorized(http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(writer, req)
|
|
})
|
|
}
|
|
|
|
// ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear
|
|
// and will remove it if it is not.
|
|
func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
|
// File does not exist, all fine
|
|
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr
|
|
return nil
|
|
}
|
|
|
|
return os.Remove(h.cfg.UnixSocket)
|
|
}
|
|
|
|
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
|
|
r := chi.NewRouter()
|
|
r.Use(metrics.Collector(metrics.CollectorOpts{
|
|
Host: false,
|
|
Proto: true,
|
|
Skip: func(r *http.Request) bool {
|
|
return r.Method != http.MethodOptions
|
|
},
|
|
}))
|
|
r.Use(middleware.RequestID)
|
|
r.Use(middleware.RealIP)
|
|
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
|
|
r.Use(middleware.Recoverer)
|
|
|
|
r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler)
|
|
|
|
r.Get("/robots.txt", h.RobotsHandler)
|
|
r.Get("/health", h.HealthHandler)
|
|
r.Get("/version", h.VersionHandler)
|
|
r.Get("/key", h.KeyHandler)
|
|
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
|
|
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
|
|
|
|
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
|
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
|
|
}
|
|
|
|
r.Get("/apple", h.AppleConfigMessage)
|
|
r.Get("/apple/{platform}", h.ApplePlatformConfig)
|
|
r.Get("/windows", h.WindowsConfigMessage)
|
|
|
|
// TODO(kristoffer): move swagger into a package
|
|
r.Get("/swagger", headscale.SwaggerUI)
|
|
r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1)
|
|
|
|
r.Post("/verify", h.VerifyHandler)
|
|
|
|
if h.cfg.DERP.ServerEnabled {
|
|
r.HandleFunc("/derp", h.DERPServer.DERPHandler)
|
|
r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
|
|
r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
|
|
r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
|
|
}
|
|
|
|
r.Route("/api", func(r chi.Router) {
|
|
r.Use(h.httpAuthenticationMiddleware)
|
|
r.HandleFunc("/v1/*", grpcMux.ServeHTTP)
|
|
})
|
|
r.Get("/favicon.ico", FaviconHandler)
|
|
r.Get("/", BlankHandler)
|
|
|
|
return r
|
|
}
|
|
|
|
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
|
//
|
|
//nolint:gocyclo // complex server startup function
|
|
func (h *Headscale) Serve() error {
|
|
var err error
|
|
|
|
capver.CanOldCodeBeCleanedUp()
|
|
|
|
if profilingEnabled {
|
|
if profilingPath != "" {
|
|
err = os.MkdirAll(profilingPath, os.ModePerm)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("failed to create profiling directory")
|
|
}
|
|
|
|
defer profile.Start(profile.ProfilePath(profilingPath)).Stop()
|
|
} else {
|
|
defer profile.Start().Stop()
|
|
}
|
|
}
|
|
|
|
if dumpConfig {
|
|
spew.Dump(h.cfg)
|
|
}
|
|
|
|
versionInfo := types.GetVersionInfo()
|
|
log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("starting headscale")
|
|
log.Info().
|
|
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
|
Msg("Clients with a lower minimum version will be rejected")
|
|
|
|
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
|
|
|
h.mapBatcher.Start()
|
|
defer h.mapBatcher.Close()
|
|
|
|
if h.cfg.DERP.ServerEnabled {
|
|
// When embedded DERP is enabled we always need a STUN server
|
|
if h.cfg.DERP.STUNAddr == "" {
|
|
return errSTUNAddressNotSet
|
|
}
|
|
|
|
go h.DERPServer.ServeSTUN()
|
|
}
|
|
|
|
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
|
if err != nil {
|
|
return fmt.Errorf("getting DERPMap: %w", err)
|
|
}
|
|
|
|
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
|
region, _ := h.DERPServer.GenerateRegion()
|
|
derpMap.Regions[region.RegionID] = ®ion
|
|
}
|
|
|
|
if len(derpMap.Regions) == 0 {
|
|
return errEmptyInitialDERPMap
|
|
}
|
|
|
|
h.state.SetDERPMap(derpMap)
|
|
|
|
// Start ephemeral node garbage collector and schedule all nodes
|
|
// that are already in the database and ephemeral. If they are still
|
|
// around between restarts, they will reconnect and the GC will
|
|
// be cancelled.
|
|
go h.ephemeralGC.Start()
|
|
|
|
ephmNodes := h.state.ListEphemeralNodes()
|
|
for _, node := range ephmNodes.All() {
|
|
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
|
}
|
|
|
|
if h.cfg.DNSConfig.ExtraRecordsPath != "" {
|
|
h.extraRecordMan, err = dns.NewExtraRecordsManager(h.cfg.DNSConfig.ExtraRecordsPath)
|
|
if err != nil {
|
|
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
|
}
|
|
|
|
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
|
|
|
go h.extraRecordMan.Run()
|
|
defer h.extraRecordMan.Close()
|
|
}
|
|
|
|
// Start all scheduled tasks, e.g. expiring nodes, derp updates and
|
|
// records updates
|
|
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
|
defer scheduleCancel()
|
|
|
|
go h.scheduledTasks(scheduleCtx)
|
|
|
|
if zl.GlobalLevel() == zl.TraceLevel {
|
|
zerolog.RespLog = true
|
|
} else {
|
|
zerolog.RespLog = false
|
|
}
|
|
|
|
// Prepare group for running listeners
|
|
errorGroup := new(errgroup.Group)
|
|
|
|
ctx := context.Background()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
//
|
|
//
|
|
// Set up LOCAL listeners
|
|
//
|
|
|
|
err = h.ensureUnixSocketIsAbsent()
|
|
if err != nil {
|
|
return fmt.Errorf("removing old socket file: %w", err)
|
|
}
|
|
|
|
socketDir := filepath.Dir(h.cfg.UnixSocket)
|
|
|
|
err = util.EnsureDir(socketDir)
|
|
if err != nil {
|
|
return fmt.Errorf("setting up unix socket: %w", err)
|
|
}
|
|
|
|
socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket)
|
|
if err != nil {
|
|
return fmt.Errorf("setting up gRPC socket: %w", err)
|
|
}
|
|
|
|
// Change socket permissions
|
|
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr
|
|
return fmt.Errorf("changing gRPC socket permission: %w", err)
|
|
}
|
|
|
|
grpcGatewayMux := grpcRuntime.NewServeMux()
|
|
|
|
// Make the grpc-gateway connect to grpc over socket
|
|
grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
|
h.cfg.UnixSocket,
|
|
[]grpc.DialOption{
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithContextDialer(util.GrpcSocketDialer),
|
|
}...,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("setting up gRPC gateway via socket: %w", err)
|
|
}
|
|
|
|
// Connect to the gRPC server over localhost to skip
|
|
// the authentication.
|
|
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
|
|
if err != nil {
|
|
return fmt.Errorf("registering Headscale API service to gRPC: %w", err)
|
|
}
|
|
|
|
// Start the local gRPC server without TLS and without authentication
|
|
grpcSocket := grpc.NewServer(
|
|
// Uncomment to debug grpc communication.
|
|
// zerolog.UnaryInterceptor(),
|
|
)
|
|
|
|
v1.RegisterHeadscaleServiceServer(grpcSocket, newHeadscaleV1APIServer(h))
|
|
reflection.Register(grpcSocket)
|
|
|
|
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
|
|
|
|
//
|
|
//
|
|
// Set up REMOTE listeners
|
|
//
|
|
|
|
tlsConfig, err := h.getTLSSettings()
|
|
if err != nil {
|
|
return fmt.Errorf("configuring TLS settings: %w", err)
|
|
}
|
|
|
|
//
|
|
//
|
|
// gRPC setup
|
|
//
|
|
|
|
// We are sadly not able to run gRPC and HTTPS (2.0) on the same
|
|
// port because the connection mux does not support matching them
|
|
// since they are so similar. There is multiple issues open and we
|
|
// can revisit this if changes:
|
|
// https://github.com/soheilhy/cmux/issues/68
|
|
// https://github.com/soheilhy/cmux/issues/91
|
|
|
|
var (
|
|
grpcServer *grpc.Server
|
|
grpcListener net.Listener
|
|
)
|
|
|
|
if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
|
|
log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
|
|
|
grpcOptions := []grpc.ServerOption{
|
|
grpc.ChainUnaryInterceptor(
|
|
h.grpcAuthenticationInterceptor,
|
|
// Uncomment to debug grpc communication.
|
|
// zerolog.NewUnaryServerInterceptor(),
|
|
),
|
|
}
|
|
|
|
if tlsConfig != nil {
|
|
grpcOptions = append(grpcOptions,
|
|
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
|
)
|
|
} else {
|
|
log.Warn().Msg("gRPC is running without security")
|
|
}
|
|
|
|
grpcServer = grpc.NewServer(grpcOptions...)
|
|
|
|
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
|
|
reflection.Register(grpcServer)
|
|
|
|
grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("binding to TCP address: %w", err)
|
|
}
|
|
|
|
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
|
|
|
log.Info().
|
|
Msgf("listening and serving gRPC on: %s", h.cfg.GRPCAddr)
|
|
}
|
|
|
|
//
|
|
//
|
|
// HTTP setup
|
|
//
|
|
// This is the regular router that we expose
|
|
// over our main Addr
|
|
router := h.createRouter(grpcGatewayMux)
|
|
|
|
httpServer := &http.Server{
|
|
Addr: h.cfg.Addr,
|
|
Handler: router,
|
|
ReadTimeout: types.HTTPTimeout,
|
|
|
|
// Long polling should not have any timeout, this is overridden
|
|
// further down the chain
|
|
WriteTimeout: types.HTTPTimeout,
|
|
}
|
|
|
|
var httpListener net.Listener
|
|
|
|
if tlsConfig != nil {
|
|
httpServer.TLSConfig = tlsConfig
|
|
httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
|
|
} else {
|
|
httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("binding to TCP address: %w", err)
|
|
}
|
|
|
|
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
|
|
|
log.Info().
|
|
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
|
|
|
|
// Only start debug/metrics server if address is configured
|
|
var debugHTTPServer *http.Server
|
|
|
|
var debugHTTPListener net.Listener
|
|
|
|
if h.cfg.MetricsAddr != "" {
|
|
debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("binding to TCP address: %w", err)
|
|
}
|
|
|
|
debugHTTPServer = h.debugHTTPServer()
|
|
|
|
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })
|
|
|
|
log.Info().
|
|
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)
|
|
} else {
|
|
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
|
}
|
|
|
|
var tailsqlContext context.Context
|
|
|
|
if tailsqlEnabled {
|
|
if h.cfg.Database.Type != types.DatabaseSqlite {
|
|
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
|
log.Fatal().
|
|
Str("type", h.cfg.Database.Type).
|
|
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
|
}
|
|
|
|
if tailsqlTSKey == "" {
|
|
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
|
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
|
}
|
|
|
|
tailsqlContext = context.Background()
|
|
|
|
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck
|
|
}
|
|
|
|
// Handle common process-killing signals so we can gracefully shut down:
|
|
sigc := make(chan os.Signal, 1)
|
|
signal.Notify(sigc,
|
|
syscall.SIGHUP,
|
|
syscall.SIGINT,
|
|
syscall.SIGTERM,
|
|
syscall.SIGQUIT,
|
|
syscall.SIGHUP)
|
|
|
|
sigFunc := func(c chan os.Signal) {
|
|
// Wait for a SIGINT or SIGKILL:
|
|
for {
|
|
sig := <-c
|
|
switch sig {
|
|
case syscall.SIGHUP:
|
|
log.Info().
|
|
Str("signal", sig.String()).
|
|
Msg("Received SIGHUP, reloading ACL policy")
|
|
|
|
if h.cfg.Policy.IsEmpty() {
|
|
continue
|
|
}
|
|
|
|
changes, err := h.state.ReloadPolicy()
|
|
if err != nil {
|
|
log.Error().Err(err).Msgf("reloading policy")
|
|
continue
|
|
}
|
|
|
|
h.Change(changes...)
|
|
|
|
default:
|
|
info := func(msg string) { log.Info().Msg(msg) }
|
|
|
|
log.Info().
|
|
Str("signal", sig.String()).
|
|
Msg("Received signal to stop, shutting down gracefully")
|
|
|
|
scheduleCancel()
|
|
h.ephemeralGC.Close()
|
|
|
|
// Gracefully shut down servers
|
|
shutdownCtx, cancel := context.WithTimeout(
|
|
context.WithoutCancel(ctx),
|
|
types.HTTPShutdownTimeout,
|
|
)
|
|
defer cancel()
|
|
|
|
if debugHTTPServer != nil {
|
|
info("shutting down debug http server")
|
|
|
|
err := debugHTTPServer.Shutdown(shutdownCtx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to shutdown prometheus http")
|
|
}
|
|
}
|
|
|
|
info("shutting down main http server")
|
|
|
|
err := httpServer.Shutdown(shutdownCtx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to shutdown http")
|
|
}
|
|
|
|
info("closing batcher")
|
|
h.mapBatcher.Close()
|
|
|
|
info("waiting for netmap stream to close")
|
|
h.clientStreamsOpen.Wait()
|
|
|
|
info("shutting down grpc server (socket)")
|
|
grpcSocket.GracefulStop()
|
|
|
|
if grpcServer != nil {
|
|
info("shutting down grpc server (external)")
|
|
grpcServer.GracefulStop()
|
|
grpcListener.Close()
|
|
}
|
|
|
|
if tailsqlContext != nil {
|
|
info("shutting down tailsql")
|
|
tailsqlContext.Done()
|
|
}
|
|
|
|
// Close network listeners
|
|
info("closing network listeners")
|
|
|
|
if debugHTTPListener != nil {
|
|
debugHTTPListener.Close()
|
|
}
|
|
|
|
httpListener.Close()
|
|
grpcGatewayConn.Close()
|
|
|
|
// Stop listening (and unlink the socket if unix type):
|
|
info("closing socket listener")
|
|
socketListener.Close()
|
|
|
|
// Close state connections
|
|
info("closing state and database")
|
|
|
|
err = h.state.Close()
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to close state")
|
|
}
|
|
|
|
log.Info().
|
|
Msg("Headscale stopped")
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
errorGroup.Go(func() error {
|
|
sigFunc(sigc)
|
|
|
|
return nil
|
|
})
|
|
|
|
return errorGroup.Wait()
|
|
}
|
|
|
|
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|
var err error
|
|
|
|
if h.cfg.TLS.LetsEncrypt.Hostname != "" {
|
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
|
log.Warn().
|
|
Msg("Listening with TLS but ServerURL does not start with https://")
|
|
}
|
|
|
|
certManager := autocert.Manager{
|
|
Prompt: autocert.AcceptTOS,
|
|
HostPolicy: autocert.HostWhitelist(h.cfg.TLS.LetsEncrypt.Hostname),
|
|
Cache: autocert.DirCache(h.cfg.TLS.LetsEncrypt.CacheDir),
|
|
Client: &acme.Client{
|
|
DirectoryURL: h.cfg.ACMEURL,
|
|
HTTPClient: &http.Client{
|
|
Transport: &acmeLogger{
|
|
rt: http.DefaultTransport,
|
|
},
|
|
},
|
|
},
|
|
Email: h.cfg.ACMEEmail,
|
|
}
|
|
|
|
switch h.cfg.TLS.LetsEncrypt.ChallengeType {
|
|
case types.TLSALPN01ChallengeType:
|
|
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
|
// The RFC requires that the validation is done on port 443; in other words, headscale
|
|
// must be reachable on port 443.
|
|
return certManager.TLSConfig(), nil
|
|
|
|
case types.HTTP01ChallengeType:
|
|
// Configuration via autocert with HTTP-01. This requires listening on
|
|
// port 80 for the certificate validation in addition to the headscale
|
|
// service, which can be configured to run on any other port.
|
|
server := &http.Server{
|
|
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
|
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
|
ReadTimeout: types.HTTPTimeout,
|
|
}
|
|
|
|
go func() {
|
|
err := server.ListenAndServe()
|
|
log.Fatal().
|
|
Caller().
|
|
Err(err).
|
|
Msg("failed to set up a HTTP server")
|
|
}()
|
|
|
|
return certManager.TLSConfig(), nil
|
|
|
|
default:
|
|
return nil, errUnsupportedLetsEncryptChallengeType
|
|
}
|
|
} else if h.cfg.TLS.CertPath == "" {
|
|
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
|
log.Warn().Msg("listening without TLS but ServerURL does not start with http://")
|
|
}
|
|
|
|
return nil, err
|
|
} else {
|
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
|
log.Warn().Msg("listening with TLS but ServerURL does not start with https://")
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
NextProtos: []string{"http/1.1"},
|
|
Certificates: make([]tls.Certificate, 1),
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLS.CertPath, h.cfg.TLS.KeyPath)
|
|
|
|
return tlsConfig, err
|
|
}
|
|
}
|
|
|
|
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
|
dir := filepath.Dir(path)
|
|
|
|
err := util.EnsureDir(dir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ensuring private key directory: %w", err)
|
|
}
|
|
|
|
privateKey, err := os.ReadFile(path)
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
log.Info().Str("path", path).Msg("no private key file at path, creating...")
|
|
|
|
machineKey := key.NewMachine()
|
|
|
|
machineKeyStr, err := machineKey.MarshalText()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"converting private key to string for saving: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"saving private key to disk at path %q: %w",
|
|
path,
|
|
err,
|
|
)
|
|
}
|
|
|
|
return &machineKey, nil
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("reading private key file: %w", err)
|
|
}
|
|
|
|
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
|
|
|
var machineKey key.MachinePrivate
|
|
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr
|
|
return nil, fmt.Errorf("parsing private key: %w", err)
|
|
}
|
|
|
|
return &machineKey, nil
|
|
}
|
|
|
|
// Change is used to send changes to nodes.
|
|
// All change should be enqueued here and empty will be automatically
|
|
// ignored.
|
|
func (h *Headscale) Change(cs ...change.Change) {
|
|
h.mapBatcher.AddWork(cs...)
|
|
}
|
|
|
|
// acmeLogger is an http.RoundTripper wrapper used for the ACME/autocert
// HTTP client. It delegates every request to rt and logs the ones that
// fail.
type acmeLogger struct {
	// rt is the transport that actually performs the request.
	rt http.RoundTripper
}
|
|
|
|
// RoundTrip will log when ACME/autocert failures happen either when err != nil OR
|
|
// when http status codes indicate a failure has occurred.
|
|
func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
resp, err := l.rt.RoundTrip(req)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("url", req.URL.String()).Msg("acme request failed")
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode >= http.StatusBadRequest {
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("acme request returned error")
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// zerologRequestLogger is the chi middleware.LogFormatter used by the
// router; it sends HTTP request logs through zerolog. It carries no state.
type zerologRequestLogger struct{}
|
|
|
|
func (z *zerologRequestLogger) NewLogEntry(
|
|
r *http.Request,
|
|
) middleware.LogEntry {
|
|
return &zerologLogEntry{
|
|
method: r.Method,
|
|
path: r.URL.Path,
|
|
proto: r.Proto,
|
|
remote: r.RemoteAddr,
|
|
}
|
|
}
|
|
|
|
// zerologLogEntry holds the request attributes captured by NewLogEntry
// until the response is logged.
type zerologLogEntry struct {
	method string // HTTP method of the request
	path   string // URL path of the request
	proto  string // protocol string of the request
	remote string // remote address of the request
}
|
|
|
|
func (e *zerologLogEntry) Write(
|
|
status, bytes int,
|
|
header http.Header,
|
|
elapsed time.Duration,
|
|
extra any,
|
|
) {
|
|
log.Info().
|
|
Str("method", e.method).
|
|
Str("path", e.path).
|
|
Str("proto", e.proto).
|
|
Str("remote", e.remote).
|
|
Int("status", status).
|
|
Int("bytes", bytes).
|
|
Dur("elapsed", elapsed).
|
|
Msg("http request")
|
|
}
|
|
|
|
func (e *zerologLogEntry) Panic(
|
|
v any,
|
|
stack []byte,
|
|
) {
|
|
log.Error().
|
|
Interface("panic", v).
|
|
Bytes("stack", stack).
|
|
Msg("http handler panic")
|
|
}
|