fix: unmarshal and some tests

This commit is contained in:
yusing
2025-04-13 12:24:11 +08:00
parent be87d47ebb
commit 3f2dfe14b5
6 changed files with 100 additions and 101 deletions

View File

@@ -47,11 +47,15 @@ var (
var (
typeDuration = reflect.TypeFor[time.Duration]()
typeTime = reflect.TypeFor[time.Time]()
typeURL = reflect.TypeFor[url.URL]()
typeCIDR = reflect.TypeFor[*net.IPNet]()
typeCIDR = reflect.TypeFor[net.IPNet]()
typeMapMarshaller = reflect.TypeFor[MapMarshaller]()
typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]()
typeJSONMarshaller = reflect.TypeFor[json.Marshaler]()
typeAny = reflect.TypeOf((*any)(nil)).Elem()
)
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
@@ -360,20 +364,26 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return err
}
case isIntFloat(srcKind):
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
if dst.Kind() == reflect.String {
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
}
dst.Set(reflect.ValueOf(strV))
return nil
}
if convertible, err := ConvertString(strV, dst); convertible {
return err
if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) {
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
dst.Set(src.Convert(dstT))
return nil
case srcKind == reflect.Map:
if src.Len() == 0 {
return nil
@@ -412,8 +422,17 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
func nilPointer[T any]() reflect.Value {
return reflect.ValueOf((*T)(nil))
func isSameOrEmbededType(src, dst reflect.Type) bool {
return src == dst || src.ConvertibleTo(dst)
}
func setSameOrEmbedddType(src, dst reflect.Value) {
dstT := dst.Type()
if src.Type().AssignableTo(dstT) {
dst.Set(src)
} else {
dst.Set(src.Convert(dstT))
}
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
@@ -430,12 +449,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
dst.SetString(src)
return
}
switch dstT {
case typeDuration:
if src == "" {
dst.Set(reflect.Zero(dstT))
return false, nil
}
if src == "" {
dst.Set(reflect.Zero(dstT))
return
}
switch {
case dstT == typeDuration:
d, err := time.ParseDuration(src)
if err != nil {
return true, gperr.Wrap(err)
@@ -445,30 +464,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
}
dst.Set(reflect.ValueOf(d))
return
case typeURL:
if src == "" {
dst.Addr().Set(nilPointer[*url.URL]())
return
}
case isSameOrEmbededType(dstT, typeURL):
u, err := url.Parse(src)
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(u).Elem())
setSameOrEmbedddType(reflect.ValueOf(u).Elem(), dst)
return
case typeCIDR:
if src == "" {
dst.Addr().Set(nilPointer[*net.IPNet]())
return
}
if !strings.Contains(src, "/") {
case isSameOrEmbededType(dstT, typeCIDR):
if !strings.ContainsRune(src, '/') {
src += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(src)
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(ipnet).Elem())
setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst)
return
}
if dstKind := dst.Kind(); isIntFloat(dstKind) {