Files
headscale/hscontrol/db/ip.go
Kristoffer Dalby 3acce2da87 errors: rewrite errors to follow go best practices
Errors should not start capitalised and they should not contain the word error
or state that they "failed" as we already know it is an error

Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
2026-02-06 07:40:29 +01:00

353 lines
8.9 KiB
Go

package db
import (
"crypto/rand"
"database/sql"
"errors"
"fmt"
"math/big"
"net/netip"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
)
var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
// IPAllocator is a singleton responsible for allocating
// IP addresses for nodes and making sure the same
// address is not handed out twice. There can only be one
// and it needs to be created before any other database
// writes occur.
type IPAllocator struct {
mu sync.Mutex
prefix4 *netip.Prefix
prefix6 *netip.Prefix
// Previous IPs handed out
prev4 netip.Addr
prev6 netip.Addr
// strategy used for handing out IP addresses.
strategy types.IPAllocationStrategy
// Set of all IPs handed out.
// This might not be in sync with the database,
// but it is more conservative. If saves to the
// database fails, the IP will be allocated here
// until the next restart of Headscale.
usedIPs netipx.IPSetBuilder
}
// NewIPAllocator returns a new IPAllocator singleton which
// can be used to hand out unique IP addresses within the
// provided IPv4 and IPv6 prefix. It needs to be created
// when headscale starts and needs to finish its read
// transaction before any writes to the database occur.
func NewIPAllocator(
db *HSDatabase,
prefix4, prefix6 *netip.Prefix,
strategy types.IPAllocationStrategy,
) (*IPAllocator, error) {
ret := IPAllocator{
prefix4: prefix4,
prefix6: prefix6,
strategy: strategy,
}
var v4s []sql.NullString
var v6s []sql.NullString
if db != nil {
err := db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err)
}
err = db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err)
}
}
var ips netipx.IPSetBuilder
// Add network and broadcast addrs to used pool so they
// are not handed out to nodes.
if prefix4 != nil {
network4, broadcast4 := util.GetIPPrefixEndpoints(*prefix4)
ips.Add(network4)
ips.Add(broadcast4)
// Use network as starting point, it will be used to call .Next()
// TODO(kradalby): Could potentially take all the IPs loaded from
// the database into account to start at a more "educated" location.
ret.prev4 = network4
}
if prefix6 != nil {
network6, broadcast6 := util.GetIPPrefixEndpoints(*prefix6)
ips.Add(network6)
ips.Add(broadcast6)
ret.prev6 = network6
}
// Fetch all the IP Addresses currently handed out from the Database
// and add them to the used IP set.
for _, addrStr := range append(v4s, v6s...) {
if addrStr.Valid {
addr, err := netip.ParseAddr(addrStr.String)
if err != nil {
return nil, fmt.Errorf("parsing IP address from database: %w", err)
}
ips.Add(addr)
}
}
// Build the initial IPSet to validate that we can use it.
_, err := ips.IPSet()
if err != nil {
return nil, fmt.Errorf(
"building initial IP Set: %w",
err,
)
}
ret.usedIPs = ips
return &ret, nil
}
func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()
var err error
var ret4 *netip.Addr
var ret6 *netip.Addr
if i.prefix4 != nil {
ret4, err = i.next(i.prev4, i.prefix4)
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
}
i.prev4 = *ret4
}
if i.prefix6 != nil {
ret6, err = i.next(i.prev6, i.prefix6)
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
}
i.prev6 = *ret6
}
return ret4, ret6, nil
}
var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()
return i.next(prev, prefix)
}
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
var err error
var ip netip.Addr
switch i.strategy {
case types.IPAllocationStrategySequential:
// Get the first IP in our prefix
ip = prev.Next()
case types.IPAllocationStrategyRandom:
ip, err = randomNext(*prefix)
if err != nil {
return nil, fmt.Errorf("getting random IP: %w", err)
}
}
// TODO(kradalby): maybe this can be done less often.
set, err := i.usedIPs.IPSet()
if err != nil {
return nil, err
}
for {
if !prefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
// Check if the IP has already been allocated
// or if it is a IP reserved by Tailscale.
if set.Contains(ip) || isTailscaleReservedIP(ip) {
switch i.strategy {
case types.IPAllocationStrategySequential:
ip = ip.Next()
case types.IPAllocationStrategyRandom:
ip, err = randomNext(*prefix)
if err != nil {
return nil, fmt.Errorf("getting random IP: %w", err)
}
}
continue
}
i.usedIPs.Add(ip)
return &ip, nil
}
}
func randomNext(pfx netip.Prefix) (netip.Addr, error) {
rang := netipx.RangeOfPrefix(pfx)
fromIP, toIP := rang.From(), rang.To()
var from, to big.Int
from.SetBytes(fromIP.AsSlice())
to.SetBytes(toIP.AsSlice())
// Find the max, this is how we can do "random range",
// get the "max" as 0 -> to - from and then add back from
// after.
tempMax := big.NewInt(0).Sub(&to, &from)
out, err := rand.Int(rand.Reader, tempMax)
if err != nil {
return netip.Addr{}, fmt.Errorf("generating random IP: %w", err)
}
valInRange := big.NewInt(0).Add(&from, out)
ip, ok := netip.AddrFromSlice(valInRange.Bytes())
if !ok {
return netip.Addr{}, errGeneratedIPBytesInvalid
}
if !pfx.Contains(ip) {
return netip.Addr{}, fmt.Errorf(
"generated ip(%s) not in prefix(%s)",
ip.String(),
pfx.String(),
)
}
return ip, nil
}
func isTailscaleReservedIP(ip netip.Addr) bool {
return tsaddr.ChromeOSVMRange().Contains(ip) ||
tsaddr.TailscaleServiceIP() == ip ||
tsaddr.TailscaleServiceIPv6() == ip
}
// BackfillNodeIPs will take a database transaction, and
// iterate through all of the current nodes in headscale
// and ensure it has IP addresses according to the current
// configuration.
// This means that if both IPv4 and IPv6 is set in the
// config, and some nodes are missing that type of IP,
// it will be added.
// If a prefix type has been removed (IPv4 or IPv6), it
// will remove the IPs in that family from the node.
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
var err error
var ret []string
err = db.Write(func(tx *gorm.DB) error {
if i == nil {
return errors.New("backfilling IPs: ip allocator was nil")
}
log.Trace().Caller().Msgf("starting to backfill IPs")
nodes, err := ListNodes(tx)
if err != nil {
return fmt.Errorf("listing nodes to backfill IPs: %w", err)
}
for _, node := range nodes {
log.Trace().Caller().EmbedObject(node).Msg("ip backfill check started because node found in database")
changed := false
// IPv4 prefix is set, but node ip is missing, alloc
if i.prefix4 != nil && node.IPv4 == nil {
ret4, err := i.nextLocked(i.prev4, i.prefix4)
if err != nil {
return fmt.Errorf("allocating IPv4 for node(%d): %w", node.ID, err)
}
node.IPv4 = ret4
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
}
// IPv6 prefix is set, but node ip is missing, alloc
if i.prefix6 != nil && node.IPv6 == nil {
ret6, err := i.nextLocked(i.prev6, i.prefix6)
if err != nil {
return fmt.Errorf("allocating IPv6 for node(%d): %w", node.ID, err)
}
node.IPv6 = ret6
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
}
// IPv4 prefix is not set, but node has IP, remove
if i.prefix4 == nil && node.IPv4 != nil {
ret = append(ret, fmt.Sprintf("removing IPv4 %q from Node(%d) %q", node.IPv4.String(), node.ID, node.Hostname))
node.IPv4 = nil
changed = true
}
// IPv6 prefix is not set, but node has IP, remove
if i.prefix6 == nil && node.IPv6 != nil {
ret = append(ret, fmt.Sprintf("removing IPv6 %q from Node(%d) %q", node.IPv6.String(), node.ID, node.Hostname))
node.IPv6 = nil
changed = true
}
if changed {
// Use Updates() with Select() to only update IP fields, avoiding overwriting
// other fields like Expiry. We need Select() because Updates() alone skips
// zero values, but we DO want to update IPv4/IPv6 to nil when removing them.
// See issue #2862.
err := tx.Model(node).Select("ipv4", "ipv6").Updates(node).Error
if err != nil {
return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err)
}
}
}
return nil
})
return ret, err
}
func (i *IPAllocator) FreeIPs(ips []netip.Addr) {
i.mu.Lock()
defer i.mu.Unlock()
for _, ip := range ips {
i.usedIPs.Remove(ip)
}
}