mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-21 16:01:22 +02:00
Cleaned up some validation code, stricter validation
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
package utils
|
||||
|
||||
// FIXME: some times [%d] is not in correct order
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -21,6 +18,10 @@ import (
|
||||
|
||||
type SerializedObject = map[string]any
|
||||
|
||||
type MapUnmarshaller interface {
|
||||
UnmarshalMap(m map[string]any) E.Error
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidType = E.New("invalid type")
|
||||
ErrNilValue = E.New("nil")
|
||||
@@ -29,6 +30,8 @@ var (
|
||||
ErrUnknownField = E.New("unknown field")
|
||||
)
|
||||
|
||||
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||
|
||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||
|
||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||
@@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, nil
|
||||
}
|
||||
var fields []reflect.StructField
|
||||
for i := range t.NumField() {
|
||||
n := t.NumField()
|
||||
fields := make([]reflect.StructField, 0, n)
|
||||
for i := range n {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
@@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
||||
return fields, anonymous
|
||||
}
|
||||
|
||||
func ValidateWithFieldTags(s any) E.Error {
|
||||
errs := E.NewBuilder("validate error")
|
||||
err := validate.Struct(s)
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(e.Namespace()).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
||||
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
|
||||
//
|
||||
// The target value must be a struct or a map[string]any.
|
||||
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
|
||||
// If the target value is a struct , and implements the MapUnmarshaller interface,
|
||||
// the UnmarshalMap method will be called.
|
||||
//
|
||||
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
|
||||
// the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||
//
|
||||
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
|
||||
//
|
||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||
func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
if src == nil {
|
||||
return E.Errorf("deserialize: src is %w", ErrNilValue)
|
||||
}
|
||||
if dst == nil {
|
||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
||||
}
|
||||
|
||||
dstV := reflect.ValueOf(dst)
|
||||
dstT := dstV.Type()
|
||||
|
||||
if src == nil {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(reflect.Zero(dstT))
|
||||
return nil
|
||||
}
|
||||
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
|
||||
}
|
||||
|
||||
if dstT.Implements(mapUnmarshalerType) {
|
||||
for dstV.IsNil() {
|
||||
switch dstT.Kind() {
|
||||
case reflect.Struct:
|
||||
dstV.Set(New(dstT))
|
||||
case reflect.Map:
|
||||
dstV.Set(reflect.MakeMap(dstT))
|
||||
case reflect.Slice:
|
||||
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
|
||||
case reflect.Ptr:
|
||||
dstV.Set(reflect.New(dstT.Elem()))
|
||||
default:
|
||||
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
|
||||
}
|
||||
dstV = dstV.Elem()
|
||||
}
|
||||
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
|
||||
}
|
||||
|
||||
for dstT.Kind() == reflect.Ptr {
|
||||
if dstV.IsNil() {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(New(dstT.Elem()))
|
||||
} else {
|
||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
||||
return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
|
||||
}
|
||||
}
|
||||
dstV = dstV.Elem()
|
||||
@@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
|
||||
switch dstV.Kind() {
|
||||
case reflect.Struct:
|
||||
needValidate := false
|
||||
hasValidateTag := false
|
||||
mapping := make(map[string]reflect.Value)
|
||||
fieldName := make(map[string]string)
|
||||
fields, anonymous := extractFields(dstT)
|
||||
for _, anon := range anonymous {
|
||||
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
|
||||
@@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
}
|
||||
key = strutils.ToLowerNoSnake(key)
|
||||
mapping[key] = dstV.FieldByName(field.Name)
|
||||
fieldName[field.Name] = key
|
||||
|
||||
if !needValidate {
|
||||
_, needValidate = field.Tag.Lookup("validate")
|
||||
if !hasValidateTag {
|
||||
_, hasValidateTag = field.Tag.Lookup("validate")
|
||||
}
|
||||
|
||||
aliases, ok := field.Tag.Lookup("aliases")
|
||||
if ok {
|
||||
for _, alias := range strutils.CommaSeperatedList(aliases) {
|
||||
mapping[alias] = dstV.FieldByName(field.Name)
|
||||
fieldName[field.Name] = alias
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
||||
}
|
||||
}
|
||||
if needValidate {
|
||||
err := validate.Struct(dstV.Interface())
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(e.StructNamespace()).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
if hasValidateTag {
|
||||
errs.Add(ValidateWithFieldTags(dstV.Interface()))
|
||||
} else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||
errs.Add(validator.Validate())
|
||||
}
|
||||
return errs.Error()
|
||||
case reflect.Map:
|
||||
@@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
errs.Add(err.Subject(k))
|
||||
}
|
||||
}
|
||||
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||
errs.Add(validator.Validate())
|
||||
}
|
||||
return errs.Error()
|
||||
default:
|
||||
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
|
||||
@@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
|
||||
return functional.NewMapFrom(m2), nil
|
||||
}
|
||||
|
||||
func DeserializeJSON[T any](data []byte, target T) E.Error {
|
||||
m := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
return Deserialize(m, target)
|
||||
}
|
||||
|
||||
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user