mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-18 15:34:38 +01:00
refactor: simplify io code and make utils module independent
This commit is contained in:
566
internal/serialization/serialization.go
Normal file
566
internal/serialization/serialization.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type SerializedObject = map[string]any
|
||||
|
||||
type MapUnmarshaller interface {
|
||||
UnmarshalMap(m map[string]any) gperr.Error
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidType = gperr.New("invalid type")
|
||||
ErrNilValue = gperr.New("nil")
|
||||
ErrUnsettable = gperr.New("unsettable")
|
||||
ErrUnsupportedConversion = gperr.New("unsupported conversion")
|
||||
ErrUnknownField = gperr.New("unknown field")
|
||||
)
|
||||
|
||||
var (
|
||||
tagDeserialize = "deserialize" // `deserialize:"-"` to exclude from deserialization
|
||||
tagJSON = "json" // share between Deserialize and json.Marshal
|
||||
tagValidate = "validate" // uses go-playground/validator
|
||||
tagAliases = "aliases" // declare aliases for fields
|
||||
)
|
||||
|
||||
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||
|
||||
var defaultValues = xsync.NewMapOf[reflect.Type, func() any]()
|
||||
|
||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||
t := reflect.TypeFor[T]()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
panic("pointer of pointer")
|
||||
}
|
||||
if _, ok := defaultValues.Load(t); ok {
|
||||
panic("default value for " + t.String() + " already registered")
|
||||
}
|
||||
defaultValues.Store(t, func() any { return factory() })
|
||||
}
|
||||
|
||||
func New(t reflect.Type) reflect.Value {
|
||||
if dv, ok := defaultValues.Load(t); ok {
|
||||
return reflect.ValueOf(dv())
|
||||
}
|
||||
return reflect.New(t)
|
||||
}
|
||||
|
||||
func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, nil
|
||||
}
|
||||
n := t.NumField()
|
||||
fields := make([]reflect.StructField, 0, n)
|
||||
for i := range n {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
if field.Tag.Get(tagDeserialize) == "-" {
|
||||
continue
|
||||
}
|
||||
if field.Anonymous {
|
||||
f1, f2 := extractFields(field.Type)
|
||||
fields = append(fields, f1...)
|
||||
anonymous = append(anonymous, field)
|
||||
anonymous = append(anonymous, f2...)
|
||||
} else {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
return fields, anonymous
|
||||
}
|
||||
|
||||
func ValidateWithFieldTags(s any) gperr.Error {
|
||||
errs := gperr.NewBuilder()
|
||||
err := validate.Struct(s)
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
if detail != "required" {
|
||||
detail = "require " + strconv.Quote(detail)
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(e.Namespace()).
|
||||
Withf(detail))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func ValidateWithCustomValidator(v reflect.Value) gperr.Error {
|
||||
isStruct := false
|
||||
for {
|
||||
switch v.Kind() {
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
if v.IsNil() {
|
||||
return gperr.Errorf("validate: v is %w", ErrNilValue)
|
||||
}
|
||||
if validate, ok := v.Interface().(CustomValidator); ok {
|
||||
return validate.Validate()
|
||||
}
|
||||
if isStruct {
|
||||
return nil
|
||||
}
|
||||
v = v.Elem()
|
||||
case reflect.Struct:
|
||||
if !v.CanAddr() {
|
||||
return nil
|
||||
}
|
||||
v = v.Addr()
|
||||
isStruct = true
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dive(dst reflect.Value) (v reflect.Value, t reflect.Type, err gperr.Error) {
|
||||
dstT := dst.Type()
|
||||
for {
|
||||
switch dst.Kind() {
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
if dst.IsNil() {
|
||||
if !dst.CanSet() {
|
||||
err = gperr.Errorf("dive: dst is %w and is not settable", ErrNilValue)
|
||||
return
|
||||
}
|
||||
dst.Set(New(dstT.Elem()))
|
||||
}
|
||||
dst = dst.Elem()
|
||||
dstT = dst.Type()
|
||||
case reflect.Struct:
|
||||
return dst, dstT, nil
|
||||
default:
|
||||
if dst.IsNil() {
|
||||
switch dst.Kind() {
|
||||
case reflect.Map:
|
||||
dst.Set(reflect.MakeMap(dstT))
|
||||
case reflect.Slice:
|
||||
dst.Set(reflect.MakeSlice(dstT, 0, 0))
|
||||
default:
|
||||
err = gperr.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
return dst, dstT, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MapUnmarshalValidate takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
||||
// MapUnmarshalValidate ignores case differences between the field names in the SerializedObject and the target.
|
||||
//
|
||||
// The target value must be a struct or a map[string]any.
|
||||
// If the target value is a struct , and implements the MapUnmarshaller interface,
|
||||
// the UnmarshalMap method will be called.
|
||||
//
|
||||
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
|
||||
// the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||
//
|
||||
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
|
||||
//
|
||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||
func MapUnmarshalValidate(src SerializedObject, dst any) (err gperr.Error) {
|
||||
return mapUnmarshalValidate(src, dst, true)
|
||||
}
|
||||
|
||||
func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool) (err gperr.Error) {
|
||||
dstV := reflect.ValueOf(dst)
|
||||
dstT := dstV.Type()
|
||||
|
||||
if src == nil {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(reflect.Zero(dstT))
|
||||
return nil
|
||||
}
|
||||
return gperr.Errorf("deserialize: src is %w and dst is not settable\n%s", ErrNilValue, debug.Stack())
|
||||
}
|
||||
|
||||
if dstT.Implements(mapUnmarshalerType) {
|
||||
dstV, _, err = dive(dstV)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return dstV.Addr().Interface().(MapUnmarshaller).UnmarshalMap(src)
|
||||
}
|
||||
|
||||
dstV, dstT, err = dive(dstV)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert data fields to lower no-snake
|
||||
// convert target fields to lower no-snake
|
||||
// then check if the field of data is in the target
|
||||
|
||||
errs := gperr.NewBuilder()
|
||||
|
||||
switch dstV.Kind() {
|
||||
case reflect.Struct, reflect.Interface:
|
||||
hasValidateTag := false
|
||||
mapping := make(map[string]reflect.Value)
|
||||
fields, anonymous := extractFields(dstT)
|
||||
for _, anon := range anonymous {
|
||||
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
|
||||
field.Set(New(anon.Type.Elem()))
|
||||
}
|
||||
}
|
||||
for _, field := range fields {
|
||||
var key string
|
||||
if jsonTag, ok := field.Tag.Lookup(tagJSON); ok {
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
key = strutils.CommaSeperatedList(jsonTag)[0]
|
||||
} else {
|
||||
key = field.Name
|
||||
}
|
||||
key = strutils.ToLowerNoSnake(key)
|
||||
mapping[key] = dstV.FieldByName(field.Name)
|
||||
|
||||
if !hasValidateTag {
|
||||
_, hasValidateTag = field.Tag.Lookup(tagValidate)
|
||||
}
|
||||
|
||||
aliases, ok := field.Tag.Lookup(tagAliases)
|
||||
if ok {
|
||||
for _, alias := range strutils.CommaSeperatedList(aliases) {
|
||||
mapping[alias] = dstV.FieldByName(field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, v := range src {
|
||||
if field, ok := mapping[strutils.ToLowerNoSnake(k)]; ok {
|
||||
err := Convert(reflect.ValueOf(v), field, !hasValidateTag)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(k))
|
||||
}
|
||||
} else {
|
||||
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping))))
|
||||
}
|
||||
}
|
||||
if hasValidateTag && checkValidateTag {
|
||||
errs.Add(ValidateWithFieldTags(dstV.Interface()))
|
||||
}
|
||||
if err := ValidateWithCustomValidator(dstV); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
return errs.Error()
|
||||
case reflect.Map:
|
||||
for k, v := range src {
|
||||
mapVT := dstT.Elem()
|
||||
tmp := New(mapVT).Elem()
|
||||
err := Convert(reflect.ValueOf(v), tmp, true)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(k))
|
||||
continue
|
||||
}
|
||||
if err := ValidateWithCustomValidator(tmp.Addr()); err != nil {
|
||||
errs.Add(err.Subject(k))
|
||||
} else {
|
||||
dstV.SetMapIndex(reflect.ValueOf(k), tmp)
|
||||
}
|
||||
}
|
||||
if err := ValidateWithCustomValidator(dstV); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
return errs.Error()
|
||||
default:
|
||||
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String() + " ")
|
||||
}
|
||||
}
|
||||
|
||||
func isIntFloat(t reflect.Kind) bool {
|
||||
return t >= reflect.Bool && t <= reflect.Float64
|
||||
}
|
||||
|
||||
// Convert attempts to convert the src to dst.
|
||||
//
|
||||
// If src is a map, it is deserialized into dst.
|
||||
// If src is a slice, each of its elements are converted and stored in dst.
|
||||
// For any other type, it is converted using the reflect.Value.Convert function (if possible).
|
||||
//
|
||||
// If dst is not settable, an error is returned.
|
||||
// If src cannot be converted to dst, an error is returned.
|
||||
// If any error occurs during conversion (e.g. deserialization), it is returned.
|
||||
//
|
||||
// Returns:
|
||||
// - error: the error occurred during conversion, or nil if no error occurred.
|
||||
func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.Error {
|
||||
if !dst.IsValid() {
|
||||
return gperr.Errorf("convert: dst is %w", ErrNilValue)
|
||||
}
|
||||
|
||||
if (src.Kind() == reflect.Pointer && src.IsNil()) || !src.IsValid() {
|
||||
if !dst.CanSet() {
|
||||
return gperr.Errorf("convert: src is %w", ErrNilValue)
|
||||
}
|
||||
// manually set nil
|
||||
dst.Set(reflect.Zero(dst.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
if src.IsZero() {
|
||||
if !dst.CanSet() {
|
||||
return gperr.Errorf("convert: src is %w", ErrNilValue)
|
||||
}
|
||||
switch dst.Kind() {
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
dst.Set(reflect.New(dst.Type().Elem()))
|
||||
default:
|
||||
dst.Set(reflect.Zero(dst.Type()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
srcT := src.Type()
|
||||
dstT := dst.Type()
|
||||
|
||||
if src.Kind() == reflect.Interface {
|
||||
src = src.Elem()
|
||||
srcT = src.Type()
|
||||
}
|
||||
|
||||
if dst.Kind() == reflect.Pointer {
|
||||
if dst.IsNil() {
|
||||
dst.Set(New(dstT.Elem()))
|
||||
}
|
||||
dst = dst.Elem()
|
||||
dstT = dst.Type()
|
||||
}
|
||||
|
||||
srcKind := srcT.Kind()
|
||||
|
||||
switch {
|
||||
case srcT.AssignableTo(dstT):
|
||||
if !dst.CanSet() {
|
||||
return ErrUnsettable.Subject(dstT.String())
|
||||
}
|
||||
dst.Set(src)
|
||||
return nil
|
||||
// case srcT.ConvertibleTo(dstT):
|
||||
// dst.Set(src.Convert(dstT))
|
||||
// return nil
|
||||
case srcKind == reflect.String:
|
||||
if !dst.CanSet() {
|
||||
return ErrUnsettable.Subject(dstT.String())
|
||||
}
|
||||
if convertible, err := ConvertString(src.String(), dst); convertible {
|
||||
return err
|
||||
}
|
||||
case isIntFloat(srcKind):
|
||||
if !dst.CanSet() {
|
||||
return ErrUnsettable.Subject(dstT.String())
|
||||
}
|
||||
var strV string
|
||||
switch {
|
||||
case src.CanInt():
|
||||
strV = strconv.FormatInt(src.Int(), 10)
|
||||
case srcKind == reflect.Bool:
|
||||
strV = strconv.FormatBool(src.Bool())
|
||||
case src.CanUint():
|
||||
strV = strconv.FormatUint(src.Uint(), 10)
|
||||
case src.CanFloat():
|
||||
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
|
||||
}
|
||||
if convertible, err := ConvertString(strV, dst); convertible {
|
||||
return err
|
||||
}
|
||||
case srcKind == reflect.Map:
|
||||
if src.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
obj, ok := src.Interface().(SerializedObject)
|
||||
if !ok {
|
||||
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
|
||||
}
|
||||
return mapUnmarshalValidate(obj, dst.Addr().Interface(), checkValidateTag)
|
||||
case srcKind == reflect.Slice:
|
||||
if src.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
if dstT.Kind() != reflect.Slice {
|
||||
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
|
||||
}
|
||||
sliceErrs := gperr.NewBuilder()
|
||||
newSlice := reflect.MakeSlice(dstT, src.Len(), src.Len())
|
||||
i := 0
|
||||
for j, v := range src.Seq2() {
|
||||
tmp := New(dstT.Elem()).Elem()
|
||||
err := Convert(v, tmp, checkValidateTag)
|
||||
if err != nil {
|
||||
sliceErrs.Add(err.Subjectf("[%d]", j))
|
||||
continue
|
||||
}
|
||||
newSlice.Index(i).Set(tmp)
|
||||
i++
|
||||
}
|
||||
if err := sliceErrs.Error(); err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Set(newSlice)
|
||||
return nil
|
||||
}
|
||||
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
|
||||
}
|
||||
|
||||
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
|
||||
convertible = true
|
||||
dstT := dst.Type()
|
||||
if dst.Kind() == reflect.Ptr {
|
||||
if dst.IsNil() {
|
||||
dst.Set(New(dstT.Elem()))
|
||||
}
|
||||
dst = dst.Elem()
|
||||
dstT = dst.Type()
|
||||
}
|
||||
if dst.Kind() == reflect.String {
|
||||
dst.SetString(src)
|
||||
return
|
||||
}
|
||||
switch dstT {
|
||||
case reflect.TypeFor[time.Duration]():
|
||||
if src == "" {
|
||||
dst.Set(reflect.Zero(dstT))
|
||||
return
|
||||
}
|
||||
d, err := time.ParseDuration(src)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(d))
|
||||
return
|
||||
default:
|
||||
}
|
||||
if dstKind := dst.Kind(); isIntFloat(dstKind) {
|
||||
var i any
|
||||
var err error
|
||||
switch {
|
||||
case dstKind == reflect.Bool:
|
||||
i, err = strconv.ParseBool(src)
|
||||
case dst.CanInt():
|
||||
i, err = strconv.ParseInt(src, 10, dstT.Bits())
|
||||
case dst.CanUint():
|
||||
i, err = strconv.ParseUint(src, 10, dstT.Bits())
|
||||
case dst.CanFloat():
|
||||
i, err = strconv.ParseFloat(src, dstT.Bits())
|
||||
}
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(i).Convert(dstT))
|
||||
return
|
||||
}
|
||||
// check if (*T).Convertor is implemented
|
||||
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
|
||||
return true, gperr.Wrap(parser.Parse(src))
|
||||
}
|
||||
// yaml like
|
||||
var tmp any
|
||||
switch dst.Kind() {
|
||||
case reflect.Slice:
|
||||
src = strings.TrimSpace(src)
|
||||
isMultiline := strings.ContainsRune(src, '\n')
|
||||
// one liner is comma separated list
|
||||
if !isMultiline && src[0] != '-' {
|
||||
values := strutils.CommaSeperatedList(src)
|
||||
dst.Set(reflect.MakeSlice(dst.Type(), len(values), len(values)))
|
||||
errs := gperr.NewBuilder()
|
||||
for i, v := range values {
|
||||
err := Convert(reflect.ValueOf(v), dst.Index(i), true)
|
||||
if err != nil {
|
||||
errs.Add(err.Subjectf("[%d]", i))
|
||||
}
|
||||
}
|
||||
if errs.HasError() {
|
||||
return true, errs.Error()
|
||||
}
|
||||
return
|
||||
}
|
||||
sl := make([]any, 0)
|
||||
err := yaml.Unmarshal([]byte(src), &sl)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
tmp = sl
|
||||
case reflect.Map, reflect.Struct:
|
||||
rawMap := make(SerializedObject)
|
||||
err := yaml.Unmarshal([]byte(src), &rawMap)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
tmp = rawMap
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
return true, Convert(reflect.ValueOf(tmp), dst, true)
|
||||
}
|
||||
|
||||
func UnmarshalValidateYAML[T any](data []byte, target *T) gperr.Error {
|
||||
m := make(map[string]any)
|
||||
if err := yaml.Unmarshal(data, &m); err != nil {
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
return MapUnmarshalValidate(m, target)
|
||||
}
|
||||
|
||||
func UnmarshalValidateYAMLXSync[V any](data []byte) (_ functional.Map[string, V], err gperr.Error) {
|
||||
m := make(map[string]any)
|
||||
if err = gperr.Wrap(yaml.Unmarshal(data, &m)); err != nil {
|
||||
return
|
||||
}
|
||||
m2 := make(map[string]V, len(m))
|
||||
if err = MapUnmarshalValidate(m, m2); err != nil {
|
||||
return
|
||||
}
|
||||
return functional.NewMapFrom(m2), nil
|
||||
}
|
||||
|
||||
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return deserialize(data, dst)
|
||||
}
|
||||
|
||||
func SaveJSON[T any](path string, src *T, perm os.FileMode) error {
|
||||
data, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
func LoadJSONIfExist[T any](path string, dst *T) error {
|
||||
err := loadSerialized(path, dst, json.Unmarshal)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
335
internal/serialization/serialization_test.go
Normal file
335
internal/serialization/serialization_test.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestDeserialize(t *testing.T) {
|
||||
type S struct {
|
||||
I int
|
||||
S string
|
||||
IS []int
|
||||
SS []string
|
||||
MSI map[string]int
|
||||
MIS map[int]string
|
||||
}
|
||||
|
||||
var (
|
||||
testStruct = S{
|
||||
I: 1,
|
||||
S: "hello",
|
||||
IS: []int{1, 2, 3},
|
||||
SS: []string{"a", "b", "c"},
|
||||
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
testStructSerialized = map[string]any{
|
||||
"I": 1,
|
||||
"S": "hello",
|
||||
"IS": []int{1, 2, 3},
|
||||
"SS": []string{"a", "b", "c"},
|
||||
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("deserialize", func(t *testing.T) {
|
||||
var s2 S
|
||||
err := MapUnmarshalValidate(testStructSerialized, &s2)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, s2, testStruct)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeserializeAnonymousField(t *testing.T) {
|
||||
type Anon struct {
|
||||
A, B int
|
||||
}
|
||||
var s struct {
|
||||
Anon
|
||||
C int
|
||||
}
|
||||
var s2 struct {
|
||||
*Anon
|
||||
C int
|
||||
}
|
||||
// all, anon := extractFields(reflect.TypeOf(s2))
|
||||
// t.Fatalf("anon %v, all %v", anon, all)
|
||||
err := MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, s.A, 1)
|
||||
ExpectEqual(t, s.B, 2)
|
||||
ExpectEqual(t, s.C, 3)
|
||||
|
||||
err = MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s2)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, s2.A, 1)
|
||||
ExpectEqual(t, s2.B, 2)
|
||||
ExpectEqual(t, s2.C, 3)
|
||||
}
|
||||
|
||||
func TestPointerPrimitives(t *testing.T) {
|
||||
type testType struct {
|
||||
B *bool `json:"b"`
|
||||
I8 *int8 `json:"i8"`
|
||||
I16 *int16 `json:"i16"`
|
||||
I32 *int32 `json:"i32"`
|
||||
I64 *int64 `json:"i64"`
|
||||
U8 *uint8 `json:"u8"`
|
||||
U16 *uint16 `json:"u16"`
|
||||
U32 *uint32 `json:"u32"`
|
||||
U64 *uint64 `json:"u64"`
|
||||
}
|
||||
var test testType
|
||||
|
||||
err := MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, *test.B, true)
|
||||
ExpectEqual(t, *test.I8, int8(127))
|
||||
ExpectEqual(t, *test.I16, int16(127))
|
||||
ExpectEqual(t, *test.I32, int32(127))
|
||||
ExpectEqual(t, *test.I64, int64(127))
|
||||
ExpectEqual(t, *test.U8, uint8(127))
|
||||
ExpectEqual(t, *test.U16, uint16(127))
|
||||
ExpectEqual(t, *test.U32, uint32(127))
|
||||
ExpectEqual(t, *test.U64, uint64(127))
|
||||
|
||||
// zero values
|
||||
err = MapUnmarshalValidate(map[string]any{"b": false, "i8": int8(0), "i16": int16(0), "i32": int32(0), "i64": int64(0), "u8": uint8(0), "u16": uint16(0), "u32": uint32(0), "u64": uint64(0)}, &test)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, *test.B, false)
|
||||
ExpectEqual(t, *test.I8, int8(0))
|
||||
ExpectEqual(t, *test.I16, int16(0))
|
||||
ExpectEqual(t, *test.I32, int32(0))
|
||||
ExpectEqual(t, *test.I64, int64(0))
|
||||
ExpectEqual(t, *test.U8, uint8(0))
|
||||
ExpectEqual(t, *test.U16, uint16(0))
|
||||
ExpectEqual(t, *test.U32, uint32(0))
|
||||
ExpectEqual(t, *test.U64, uint64(0))
|
||||
|
||||
// nil values
|
||||
err = MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
|
||||
ExpectNoError(t, err)
|
||||
err = MapUnmarshalValidate(map[string]any{"b": nil, "i8": nil, "i16": nil, "i32": nil, "i64": nil, "u8": nil, "u16": nil, "u32": nil, "u64": nil}, &test)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.B, nil)
|
||||
ExpectEqual(t, test.I8, nil)
|
||||
ExpectEqual(t, test.I16, nil)
|
||||
ExpectEqual(t, test.I32, nil)
|
||||
ExpectEqual(t, test.I64, nil)
|
||||
ExpectEqual(t, test.U8, nil)
|
||||
ExpectEqual(t, test.U16, nil)
|
||||
ExpectEqual(t, test.U32, nil)
|
||||
ExpectEqual(t, test.U64, nil)
|
||||
}
|
||||
|
||||
func TestStringIntConvert(t *testing.T) {
|
||||
s := "127"
|
||||
|
||||
test := struct {
|
||||
i8 int8
|
||||
i16 int16
|
||||
i32 int32
|
||||
i64 int64
|
||||
u8 uint8
|
||||
u16 uint16
|
||||
u32 uint32
|
||||
u64 uint64
|
||||
}{}
|
||||
|
||||
ok, err := ConvertString(s, reflect.ValueOf(&test.i8))
|
||||
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.i8, int8(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i16))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.i16, int16(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i32))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.i32, int32(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.i64, int64(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.u8, uint8(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.u16, uint16(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.u32, uint32(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, test.u64, uint64(127))
|
||||
}
|
||||
|
||||
type testModel struct {
|
||||
Test testType
|
||||
Baz string
|
||||
}
|
||||
|
||||
type testType struct {
|
||||
foo int
|
||||
bar string
|
||||
}
|
||||
|
||||
func (c *testType) Parse(v string) (err error) {
|
||||
c.bar = v
|
||||
c.foo, err = strconv.Atoi(v)
|
||||
return
|
||||
}
|
||||
|
||||
func TestConvertor(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
m := new(testModel)
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
|
||||
|
||||
ExpectEqual(t, m.Test.foo, 123)
|
||||
ExpectEqual(t, m.Test.bar, "123")
|
||||
})
|
||||
|
||||
t.Run("int_to_string", func(t *testing.T) {
|
||||
m := new(testModel)
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
|
||||
|
||||
ExpectEqual(t, m.Test.foo, 123)
|
||||
ExpectEqual(t, m.Test.bar, "123")
|
||||
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m))
|
||||
ExpectEqual(t, m.Baz, "123")
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
m := new(testModel)
|
||||
err := MapUnmarshalValidate(map[string]any{"Test": struct{ a int }{1}}, m)
|
||||
ExpectError(t, ErrUnsupportedConversion, err)
|
||||
})
|
||||
|
||||
t.Run("set_empty", func(t *testing.T) {
|
||||
m := testModel{
|
||||
Test: testType{1, "2"},
|
||||
Baz: "3",
|
||||
}
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": nil, "Baz": nil}, &m))
|
||||
ExpectEqual(t, m, testModel{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStringToSlice(t *testing.T) {
|
||||
t.Run("comma_separated", func(t *testing.T) {
|
||||
dst := make([]string, 0)
|
||||
convertible, err := ConvertString("a,b,c", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, dst, []string{"a", "b", "c"})
|
||||
})
|
||||
t.Run("yaml-like", func(t *testing.T) {
|
||||
dst := make([]string, 0)
|
||||
convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, dst, []string{"a", "b", "c"})
|
||||
})
|
||||
t.Run("single-line-yaml-like", func(t *testing.T) {
|
||||
dst := make([]string, 0)
|
||||
convertible, err := ConvertString("- a", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, dst, []string{"a"})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkStringToSlice(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := make([]int, 0)
|
||||
_, _ = ConvertString("- 1\n- 2\n- 3", reflect.ValueOf(&dst))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToSliceYAML(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := make([]int, 0)
|
||||
_ = yaml.Unmarshal([]byte("- 1\n- 2\n- 3"), &dst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToMap(t *testing.T) {
|
||||
t.Run("yaml-like", func(t *testing.T) {
|
||||
dst := make(map[string]string)
|
||||
convertible, err := ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, dst, map[string]string{"a": "b", "c": "d"})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkStringToMap(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := make(map[string]string)
|
||||
_, _ = ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToMapYAML(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := make(map[string]string)
|
||||
_ = yaml.Unmarshal([]byte(" a: b\n c: d"), &dst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToStruct(t *testing.T) {
|
||||
t.Run("yaml-like", func(t *testing.T) {
|
||||
dst := struct {
|
||||
A string
|
||||
B int
|
||||
}{}
|
||||
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, dst, struct {
|
||||
A string
|
||||
B int
|
||||
}{"a", 123})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkStringToStruct(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
}{}
|
||||
_, _ = ConvertString(" a: a\n b: 123", reflect.ValueOf(&dst))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToStructYAML(b *testing.B) {
|
||||
for range b.N {
|
||||
dst := struct {
|
||||
A string `yaml:"a"`
|
||||
B int `yaml:"b"`
|
||||
}{}
|
||||
_ = yaml.Unmarshal([]byte(" a: a\n b: 123"), &dst)
|
||||
}
|
||||
}
|
||||
25
internal/serialization/validation.go
Normal file
25
internal/serialization/validation.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
var validate = validator.New()
|
||||
|
||||
var ErrValidationError = gperr.New("validation error")
|
||||
|
||||
type CustomValidator interface {
|
||||
Validate() gperr.Error
|
||||
}
|
||||
|
||||
func Validator() *validator.Validate {
|
||||
return validate
|
||||
}
|
||||
|
||||
func MustRegisterValidation(tag string, fn validator.Func) {
|
||||
err := validate.RegisterValidation(tag, fn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user