mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-19 15:01:22 +02:00
improved deserialization method
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -30,6 +31,28 @@ var (
|
||||
ErrUnknownField = E.New("unknown field")
|
||||
)
|
||||
|
||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||
|
||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||
t := reflect.TypeFor[T]()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
panic("pointer of pointer")
|
||||
}
|
||||
if defaultValues.Has(t) {
|
||||
panic("default value for " + t.String() + " already registered")
|
||||
}
|
||||
defaultValues.Store(t, func() any { return factory() })
|
||||
}
|
||||
|
||||
func New(t reflect.Type) reflect.Value {
|
||||
if dv, ok := defaultValues.Load(t); ok {
|
||||
logging.Debug().Str("type", t.String()).Msg("using default value")
|
||||
return reflect.ValueOf(dv())
|
||||
}
|
||||
logging.Debug().Str("type", t.String()).Msg("using zero value")
|
||||
return reflect.New(t)
|
||||
}
|
||||
|
||||
// Serialize converts the given data into a map[string]any representation.
|
||||
//
|
||||
// It uses reflection to inspect the data type and handle different kinds of data.
|
||||
@@ -150,7 +173,7 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
for dstT.Kind() == reflect.Ptr {
|
||||
if dstV.IsNil() {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(reflect.New(dstT.Elem()))
|
||||
dstV.Set(New(dstT.Elem()))
|
||||
} else {
|
||||
return E.Errorf("deserialize: dst is %w", ErrNilValue)
|
||||
}
|
||||
@@ -214,12 +237,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
fieldName, ok := fieldName[e.Field()]
|
||||
if !ok {
|
||||
fieldName = e.Field()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(fieldName).
|
||||
Subject(e.StructNamespace()).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
@@ -230,12 +249,14 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
dstV.Set(reflect.MakeMap(dstT))
|
||||
}
|
||||
for k := range src {
|
||||
tmp := reflect.New(dstT.Elem()).Elem()
|
||||
mapVT := dstT.Elem()
|
||||
tmp := New(mapVT).Elem()
|
||||
err := Convert(reflect.ValueOf(src[k]), tmp)
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
|
||||
} else {
|
||||
errs.Add(err.Subject(k))
|
||||
}
|
||||
dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
|
||||
}
|
||||
return errs.Error()
|
||||
default:
|
||||
@@ -243,6 +264,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
}
|
||||
}
|
||||
|
||||
func isIntFloat(t reflect.Kind) bool {
|
||||
return t >= reflect.Bool && t <= reflect.Float64
|
||||
}
|
||||
|
||||
// Convert attempts to convert the src to dst.
|
||||
//
|
||||
// If src is a map, it is deserialized into dst.
|
||||
@@ -270,20 +295,41 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
|
||||
if dst.Kind() == reflect.Pointer {
|
||||
if dst.IsNil() {
|
||||
dst.Set(reflect.New(dstT.Elem()))
|
||||
dst.Set(New(dstT.Elem()))
|
||||
}
|
||||
dst = dst.Elem()
|
||||
dstT = dst.Type()
|
||||
}
|
||||
|
||||
srcKind := srcT.Kind()
|
||||
|
||||
switch {
|
||||
case srcT.AssignableTo(dstT):
|
||||
dst.Set(src)
|
||||
return nil
|
||||
case srcT.ConvertibleTo(dstT):
|
||||
dst.Set(src.Convert(dstT))
|
||||
return nil
|
||||
case srcT.Kind() == reflect.Map:
|
||||
// case srcT.ConvertibleTo(dstT):
|
||||
// dst.Set(src.Convert(dstT))
|
||||
// return nil
|
||||
case srcKind == reflect.String:
|
||||
if convertible, err := ConvertString(src.String(), dst); convertible {
|
||||
return err
|
||||
}
|
||||
case isIntFloat(srcKind):
|
||||
var strV string
|
||||
switch {
|
||||
case src.CanInt():
|
||||
strV = strconv.FormatInt(src.Int(), 10)
|
||||
case srcKind == reflect.Bool:
|
||||
strV = strconv.FormatBool(src.Bool())
|
||||
case src.CanUint():
|
||||
strV = strconv.FormatUint(src.Uint(), 10)
|
||||
case src.CanFloat():
|
||||
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
|
||||
}
|
||||
if convertible, err := ConvertString(strV, dst); convertible {
|
||||
return err
|
||||
}
|
||||
case srcKind == reflect.Map:
|
||||
if src.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -292,7 +338,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
|
||||
}
|
||||
return Deserialize(obj, dst.Addr().Interface())
|
||||
case srcT.Kind() == reflect.Slice:
|
||||
case srcKind == reflect.Slice:
|
||||
if src.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -302,7 +348,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
newSlice := reflect.MakeSlice(dstT, 0, src.Len())
|
||||
i := 0
|
||||
for _, v := range src.Seq2() {
|
||||
tmp := reflect.New(dstT.Elem()).Elem()
|
||||
tmp := New(dstT.Elem()).Elem()
|
||||
err := Convert(v, tmp)
|
||||
if err != nil {
|
||||
return err.Subjectf("[%d]", i)
|
||||
@@ -312,24 +358,16 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
}
|
||||
dst.Set(newSlice)
|
||||
return nil
|
||||
case src.Kind() == reflect.String:
|
||||
if convertible, err := ConvertString(src.String(), dst); convertible {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check if (*T).Convertor is implemented
|
||||
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
|
||||
return E.From(parser.Parse(src.String()))
|
||||
}
|
||||
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
|
||||
}
|
||||
|
||||
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
|
||||
convertible = true
|
||||
dstT := dst.Type()
|
||||
if dst.Kind() == reflect.Ptr {
|
||||
if dst.IsNil() {
|
||||
dst.Set(reflect.New(dst.Type().Elem()))
|
||||
dst.Set(New(dstT.Elem()))
|
||||
}
|
||||
dst = dst.Elem()
|
||||
}
|
||||
@@ -337,10 +375,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
||||
dst.SetString(src)
|
||||
return
|
||||
}
|
||||
switch dst.Type() {
|
||||
switch dstT {
|
||||
case reflect.TypeFor[time.Duration]():
|
||||
if src == "" {
|
||||
dst.Set(reflect.Zero(dst.Type()))
|
||||
dst.Set(reflect.Zero(dstT))
|
||||
return
|
||||
}
|
||||
d, err := time.ParseDuration(src)
|
||||
@@ -357,34 +395,33 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
||||
if err != nil {
|
||||
return true, E.From(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(*ipnet))
|
||||
dst.Set(reflect.ValueOf(ipnet).Elem())
|
||||
return
|
||||
default:
|
||||
}
|
||||
// primitive types / simple types
|
||||
switch dst.Kind() {
|
||||
case reflect.Bool:
|
||||
b, err := strconv.ParseBool(src)
|
||||
if dstKind := dst.Kind(); isIntFloat(dstKind) {
|
||||
var i any
|
||||
var err error
|
||||
switch {
|
||||
case dstKind == reflect.Bool:
|
||||
i, err = strconv.ParseBool(src)
|
||||
case dst.CanInt():
|
||||
i, err = strconv.ParseInt(src, 10, dstT.Bits())
|
||||
case dst.CanUint():
|
||||
i, err = strconv.ParseUint(src, 10, dstT.Bits())
|
||||
case dst.CanFloat():
|
||||
i, err = strconv.ParseFloat(src, dstT.Bits())
|
||||
}
|
||||
if err != nil {
|
||||
return true, E.From(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(b))
|
||||
return
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
i, err := strconv.ParseInt(src, 10, 64)
|
||||
if err != nil {
|
||||
return true, E.From(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(i).Convert(dst.Type()))
|
||||
return
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
i, err := strconv.ParseUint(src, 10, 64)
|
||||
if err != nil {
|
||||
return true, E.From(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(i).Convert(dst.Type()))
|
||||
dst.Set(reflect.ValueOf(i).Convert(dstT))
|
||||
return
|
||||
}
|
||||
// check if (*T).Convertor is implemented
|
||||
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
|
||||
return true, E.From(parser.Parse(src))
|
||||
}
|
||||
// yaml like
|
||||
lines := []string{}
|
||||
src = strings.TrimSpace(src)
|
||||
@@ -446,10 +483,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
||||
}
|
||||
tmp = m
|
||||
}
|
||||
if tmp == nil {
|
||||
return false, nil
|
||||
if tmp != nil {
|
||||
return true, Convert(reflect.ValueOf(tmp), dst)
|
||||
}
|
||||
return true, Convert(reflect.ValueOf(tmp), dst)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func DeserializeYAML[T any](data []byte, target T) E.Error {
|
||||
|
||||
@@ -124,6 +124,7 @@ func TestStringIntConvert(t *testing.T) {
|
||||
|
||||
type testModel struct {
|
||||
Test testType
|
||||
Baz string
|
||||
}
|
||||
|
||||
type testType struct {
|
||||
@@ -146,8 +147,19 @@ func TestConvertor(t *testing.T) {
|
||||
ExpectEqual(t, m.Test.bar, "123")
|
||||
})
|
||||
|
||||
t.Run("int_to_string", func(t *testing.T) {
|
||||
m := new(testModel)
|
||||
ExpectNoError(t, Deserialize(map[string]any{"Test": "123"}, m))
|
||||
|
||||
ExpectEqual(t, m.Test.foo, 123)
|
||||
ExpectEqual(t, m.Test.bar, "123")
|
||||
|
||||
ExpectNoError(t, Deserialize(map[string]any{"Baz": 123}, m))
|
||||
ExpectEqual(t, m.Baz, "123")
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
m := new(testModel)
|
||||
ExpectError(t, strconv.ErrSyntax, Deserialize(map[string]any{"Test": 123}, m))
|
||||
ExpectError(t, ErrUnsupportedConversion, Deserialize(map[string]any{"Test": struct{}{}}, m))
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user