Files
headscale/hscontrol/db/text_serialiser.go
Kristoffer Dalby 3acce2da87 errors: rewrite errors to follow go best practices
Error strings should not start with a capital letter, should not contain the word "error",
and should not state that something "failed", as we already know it is an error.

Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
2026-02-06 07:40:29 +01:00

102 lines
2.7 KiB
Go

package db
import (
"context"
"encoding"
"fmt"
"reflect"
"gorm.io/gorm/schema"
)
// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
// isTextUnmarshaler reports whether the type of rv implements
// encoding.TextUnmarshaler. Only rv's type is inspected, never the value it
// holds, so a nil pointer of a suitable type still reports true. rv must be a
// valid reflect.Value (reflect.Value.Type panics on the zero Value).
func isTextUnmarshaler(rv reflect.Value) bool {
	typ := rv.Type()

	return typ.Implements(textUnmarshalerType)
}
// maybeInstantiatePtr replaces a nil pointer held in rv with a pointer to a
// freshly allocated zero value of the element type. Pointers that are already
// non-nil and values of any other kind are left untouched. When an allocation
// happens, rv must be settable, otherwise reflect.Value.Set panics.
func maybeInstantiatePtr(rv reflect.Value) {
	needsAlloc := rv.Kind() == reflect.Ptr && rv.IsNil()
	if !needsAlloc {
		return
	}

	rv.Set(reflect.New(rv.Type().Elem()))
}
// decodingError annotates err with the name of the field being decoded. The
// original error is wrapped with %w, so it stays reachable through errors.Is
// and errors.As.
func decodingError(name string, err error) error {
	const format = "decoding to %s: %w"

	return fmt.Errorf(format, name, err)
}
// TextSerialiser is a gorm serializer for fields whose type round-trips
// through a textual representation. Its Scan and Value methods match the
// method set of gorm's schema.SerializerInterface.
//
// Reading from the database (Scan) requires the field type, or a pointer to
// it, to implement encoding.TextUnmarshaler. Writing to the database (Value)
// requires the field value to implement encoding.TextMarshaler; the marshalled
// text is stored as a string.
type TextSerialiser struct{}
// Scan decodes dbValue, as read from the database, into the field described
// by field on the destination struct dst. Decoding uses the field type's
// encoding.TextUnmarshaler implementation.
//
// Both value fields (T) and pointer fields (*T) are supported, provided *T
// implements encoding.TextUnmarshaler. dbValue must be a []byte or a string.
//
// A nil dbValue returns nil without touching dst.
// NOTE(review): a reused destination therefore keeps its previous value when
// the column is NULL; confirm this is intended.
//
// Errors are returned when:
//   - dbValue is of any other type;
//   - the field type does not implement encoding.TextUnmarshaler;
//   - UnmarshalText fails, in which case the error is wrapped with the field
//     name.
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue any) error {
	if dbValue == nil {
		return nil
	}

	var raw []byte
	switch v := dbValue.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("unmarshalling text value: %#v", dbValue)
	}

	// fieldValue is always a *T:
	//   - value field: reflect.New(T);
	//   - pointer field: the settable, initially nil *T inside reflect.New(*T).
	// Dereferencing in the pointer case avoids ending up with a **T.
	fieldValue := reflect.New(field.FieldType)
	if fieldValue.Elem().Kind() == reflect.Ptr {
		fieldValue = fieldValue.Elem()
	}

	if !isTextUnmarshaler(fieldValue) {
		return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
	}

	// In the pointer-field case fieldValue is still nil; allocate it so
	// UnmarshalText has somewhere to write.
	maybeInstantiatePtr(fieldValue)

	// isTextUnmarshaler already verified the interface is implemented, so this
	// assertion cannot fail. Calling the method directly avoids a string-keyed
	// MethodByName lookup plus reflect.Call for every scanned value.
	unmarshaler := fieldValue.Interface().(encoding.TextUnmarshaler)
	if err := unmarshaler.UnmarshalText(raw); err != nil {
		return decodingError(field.Name, err)
	}

	// A pointer field receives the *T itself; a value field receives the T it
	// points to.
	dstField := field.ReflectValueOf(ctx, dst)
	if dstField.Kind() == reflect.Ptr {
		dstField.Set(fieldValue)
	} else {
		dstField.Set(fieldValue.Elem())
	}

	return nil
}
// Value converts fieldValue into the value written to the database.
//
// fieldValue must implement encoding.TextMarshaler. Its MarshalText output is
// returned as a string, and errors from MarshalText are returned unchanged.
// A nil pointer yields (nil, nil). Any other input, including a nil interface,
// is rejected with an error that names its dynamic type. ctx, field and dst
// are not used.
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue any) (any, error) {
	marshaler, ok := fieldValue.(encoding.TextMarshaler)
	if !ok {
		// %T (type), not %t (boolean verb), so the message names the offending type.
		return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %T", fieldValue)
	}

	// An interface holding a typed nil pointer is itself non-nil, so the nil
	// pointer has to be detected through reflection. Calling MarshalText on it
	// could dereference nil.
	// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
	if rv := reflect.ValueOf(marshaler); rv.Kind() == reflect.Ptr && rv.IsNil() {
		return nil, nil
	}

	text, err := marshaler.MarshalText()
	if err != nil {
		return nil, err
	}

	return string(text), nil
}