Cleaned up some validation code, stricter validation

This commit is contained in:
yusing
2025-01-26 14:43:48 +08:00
parent 254224c0e8
commit 1586610a44
23 changed files with 590 additions and 468 deletions

View File

@@ -1,13 +1,10 @@
package utils
// FIXME: some times [%d] is not in correct order
import (
"encoding/json"
"errors"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
@@ -21,6 +18,10 @@ import (
type SerializedObject = map[string]any
type MapUnmarshaller interface {
UnmarshalMap(m map[string]any) E.Error
}
var (
ErrInvalidType = E.New("invalid type")
ErrNilValue = E.New("nil")
@@ -29,6 +30,8 @@ var (
ErrUnknownField = E.New("unknown field")
)
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
@@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
if t.Kind() != reflect.Struct {
return nil, nil
}
var fields []reflect.StructField
for i := range t.NumField() {
n := t.NumField()
fields := make([]reflect.StructField, 0, n)
for i := range n {
field := t.Field(i)
if !field.IsExported() {
continue
@@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
return fields, anonymous
}
func ValidateWithFieldTags(s any) E.Error {
errs := E.NewBuilder("validate error")
err := validate.Struct(s)
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.Namespace()).
Withf("require %q", detail))
}
}
return errs.Error()
}
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
//
// The target value must be a struct or a map[string]any.
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed.
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
// If the target value is a struct , and implements the MapUnmarshaller interface,
// the UnmarshalMap method will be called.
//
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
// the SerializedObject will be deserialized into the struct fields and validate if needed.
//
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.Error {
if src == nil {
return E.Errorf("deserialize: src is %w", ErrNilValue)
}
if dst == nil {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
}
dstV := reflect.ValueOf(dst)
dstT := dstV.Type()
if src == nil {
if dstV.CanSet() {
dstV.Set(reflect.Zero(dstT))
return nil
}
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
}
if dstT.Implements(mapUnmarshalerType) {
for dstV.IsNil() {
switch dstT.Kind() {
case reflect.Struct:
dstV.Set(New(dstT))
case reflect.Map:
dstV.Set(reflect.MakeMap(dstT))
case reflect.Slice:
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
case reflect.Ptr:
dstV.Set(reflect.New(dstT.Elem()))
default:
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
}
dstV = dstV.Elem()
}
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
}
for dstT.Kind() == reflect.Ptr {
if dstV.IsNil() {
if dstV.CanSet() {
dstV.Set(New(dstT.Elem()))
} else {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
}
}
dstV = dstV.Elem()
@@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
switch dstV.Kind() {
case reflect.Struct:
needValidate := false
hasValidateTag := false
mapping := make(map[string]reflect.Value)
fieldName := make(map[string]string)
fields, anonymous := extractFields(dstT)
for _, anon := range anonymous {
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
@@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
}
key = strutils.ToLowerNoSnake(key)
mapping[key] = dstV.FieldByName(field.Name)
fieldName[field.Name] = key
if !needValidate {
_, needValidate = field.Tag.Lookup("validate")
if !hasValidateTag {
_, hasValidateTag = field.Tag.Lookup("validate")
}
aliases, ok := field.Tag.Lookup("aliases")
if ok {
for _, alias := range strutils.CommaSeperatedList(aliases) {
mapping[alias] = dstV.FieldByName(field.Name)
fieldName[field.Name] = alias
}
}
}
@@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
}
}
if needValidate {
err := validate.Struct(dstV.Interface())
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.StructNamespace()).
Withf("require %q", detail))
}
}
if hasValidateTag {
errs.Add(ValidateWithFieldTags(dstV.Interface()))
} else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
errs.Add(validator.Validate())
}
return errs.Error()
case reflect.Map:
@@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(err.Subject(k))
}
}
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
errs.Add(validator.Validate())
}
return errs.Error()
default:
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
@@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
return functional.NewMapFrom(m2), nil
}
func DeserializeJSON[T any](data []byte, target T) E.Error {
m := make(map[string]any)
if err := json.Unmarshal(data, &m); err != nil {
return E.From(err)
}
return Deserialize(m, target)
}
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
data, err := os.ReadFile(path)
if err != nil {