From feafdf05f2c08b3549ea87c16914fba1f168d5fe Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 15 Oct 2025 14:20:47 +0800 Subject: [PATCH] fix(validation): correct CustomValidator and strutils.Parser handling, add tests --- internal/serialization/serialization.go | 43 +----- internal/serialization/validation.go | 43 +++++- .../serialization/validation_common_test.go | 34 +++++ .../serialization/validation_mismatch_test.go | 126 ++++++++++++++++++ .../validation_string_ptr_test.go | 93 +++++++++++++ .../serialization/validation_string_test.go | 85 ++++++++++++ .../validation_struct_ptr_test.go | 69 ++++++++++ .../serialization/validation_struct_test.go | 66 +++++++++ internal/serialization/validation_test.go | 1 + 9 files changed, 519 insertions(+), 41 deletions(-) create mode 100644 internal/serialization/validation_common_test.go create mode 100644 internal/serialization/validation_mismatch_test.go create mode 100644 internal/serialization/validation_string_ptr_test.go create mode 100644 internal/serialization/validation_string_test.go create mode 100644 internal/serialization/validation_struct_ptr_test.go create mode 100644 internal/serialization/validation_struct_test.go create mode 100644 internal/serialization/validation_test.go diff --git a/internal/serialization/serialization.go b/internal/serialization/serialization.go index 5ea44243..cdcbe82a 100644 --- a/internal/serialization/serialization.go +++ b/internal/serialization/serialization.go @@ -90,37 +90,6 @@ func ValidateWithFieldTags(s any) gperr.Error { return errs.Error() } -var validatorType = reflect.TypeFor[CustomValidator]() - -func ValidateWithCustomValidator(v reflect.Value) gperr.Error { - if v.Kind() == reflect.Struct { - if v.Type().Implements(validatorType) { - return v.Interface().(CustomValidator).Validate() - } - if v.CanAddr() { - return validateWithValidator(v.Addr()) - } - return nil - } - if v.Kind() == reflect.Pointer { - if v.IsNil() { - return nil - } - if v.Type().Implements(validatorType) { - return v.Interface().(CustomValidator).Validate() - } - return validateWithValidator(v.Elem()) - } - return nil -} - -func validateWithValidator(v reflect.Value) gperr.Error { - if v.Type().Implements(validatorType) { - return v.Interface().(CustomValidator).Validate() - } - return nil -} - func dive(dst reflect.Value) (v reflect.Value, t reflect.Type, err gperr.Error) { dstT := dst.Type() for { @@ -529,6 +498,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe default: } + // check if (*T).Convertor is implemented + if dst.Addr().Type().Implements(parserType) { + parser := dst.Addr().Interface().(strutils.Parser) + return true, gperr.Wrap(parser.Parse(src)) + } + if gi.ReflectIsNumeric(dst) || dst.Kind() == reflect.Bool { err := gi.ReflectStrToNumBool(dst, src) if err != nil { @@ -537,12 +512,6 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe return true, nil } - // check if (*T).Convertor is implemented - if dst.Addr().Type().Implements(parserType) { - parser := dst.Addr().Interface().(strutils.Parser) - return true, gperr.Wrap(parser.Parse(src)) - } - // yaml like var tmp any switch dst.Kind() { diff --git a/internal/serialization/validation.go b/internal/serialization/validation.go index 058fa9a0..650ca740 100644 --- a/internal/serialization/validation.go +++ b/internal/serialization/validation.go @@ -1,6 +1,8 @@ package serialization import ( + "reflect" + "github.com/go-playground/validator/v10" gperr "github.com/yusing/goutils/errs" ) @@ -9,10 +11,6 @@ var validate = validator.New() var ErrValidationError = gperr.New("validation error") -type CustomValidator interface { - Validate() gperr.Error -} - func Validator() *validator.Validate { return validate } @@ -23,3 +21,40 @@ func MustRegisterValidation(tag string, fn validator.Func) { panic(err) } } + +type CustomValidator interface { + Validate() gperr.Error +} + +var validatorType = reflect.TypeFor[CustomValidator]() + +func ValidateWithCustomValidator(v reflect.Value) gperr.Error { + if v.Kind() == reflect.Pointer { + if v.IsNil() { + // return nil + return validateWithValidator(reflect.New(v.Type().Elem())) + } + if v.Type().Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + return validateWithValidator(v.Elem()) + } else { + vt := v.Type() + if vt.PkgPath() != "" { // not a builtin type + if vt.Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + if v.CanAddr() { + return validateWithValidator(v.Addr()) + } + } + } + return nil +} + +func validateWithValidator(v reflect.Value) gperr.Error { + if v.Type().Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + return nil +} diff --git a/internal/serialization/validation_common_test.go b/internal/serialization/validation_common_test.go new file mode 100644 index 00000000..739a3f4c --- /dev/null +++ b/internal/serialization/validation_common_test.go @@ -0,0 +1,34 @@ +package serialization + +import ( + "testing" + + "github.com/go-playground/validator/v10" +) + +// Common helper functions +func ptr[T any](s T) *T { + return &s +} + +// Common test function for MustRegisterValidation +func TestMustRegisterValidation(t *testing.T) { + // Test registering a custom validation + fn := func(fl validator.FieldLevel) bool { + return fl.Field().String() != "invalid" + } + + // This should not panic + MustRegisterValidation("test_tag", fn) + + // Verify the validation was registered + err := validate.VarWithValue("valid", "test", "test_tag") + if err != nil { + t.Errorf("Expected validation to pass, got error: %v", err) + } + + err = validate.VarWithValue("invalid", "test", "test_tag") + if err == nil { + t.Error("Expected validation to fail") + } +} diff --git a/internal/serialization/validation_mismatch_test.go b/internal/serialization/validation_mismatch_test.go new file mode 100644 index 00000000..15aad4ca --- /dev/null +++ b/internal/serialization/validation_mismatch_test.go @@ -0,0 +1,126 @@ +package serialization + +import ( + "reflect" + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +// Test cases for when *T implements CustomValidator but T is passed in +type CustomValidatingInt int + +func (c *CustomValidatingInt) Validate() gperr.Error { + if c == nil { + return gperr.New("pointer int cannot be nil") + } + if *c <= 0 { + return gperr.New("int must be positive") + } + if *c > 100 { + return gperr.New("int must be <= 100") + } + return nil +} + +// Test cases for when T implements CustomValidator but *T is passed in +type CustomValidatingFloat float64 + +func (c CustomValidatingFloat) Validate() gperr.Error { + if c < 0 { + return gperr.New("float must be non-negative") + } + if c > 1000 { + return gperr.New("float must be <= 1000") + } + return nil +} + +func TestValidateWithCustomValidator_PointerMethodButValuePassed(t *testing.T) { + tests := []struct { + name string + input CustomValidatingInt + wantErr bool + }{ + {"custom validating int as value - valid", CustomValidatingInt(50), false}, + {"custom validating int as value - zero", CustomValidatingInt(0), false}, + {"custom validating int as value - negative", CustomValidatingInt(-5), false}, + {"custom validating int as value - large", CustomValidatingInt(200), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_PointerMethodWithPointerPassed(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingInt + wantErr bool + }{ + {"valid custom validating int pointer", ptr(CustomValidatingInt(50)), false}, + {"nil custom validating int pointer", nil, true}, // Should fail because Validate() checks for nil + {"invalid custom validating int pointer - zero", ptr(CustomValidatingInt(0)), true}, + {"invalid custom validating int pointer - negative", ptr(CustomValidatingInt(-5)), true}, + {"invalid custom validating int pointer - too large", ptr(CustomValidatingInt(200)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_ValueMethodButPointerPassed(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingFloat + wantErr bool + }{ + {"valid custom validating float pointer", ptr(CustomValidatingFloat(50.5)), false}, + {"nil custom validating float pointer", nil, false}, + {"invalid custom validating float pointer - negative", ptr(CustomValidatingFloat(-5.5)), true}, + {"invalid custom validating float pointer - too large", ptr(CustomValidatingFloat(2000.5)), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_ValueMethodWithValuePassed(t *testing.T) { + tests := []struct { + name string + input CustomValidatingFloat + wantErr bool + }{ + {"valid custom validating float", CustomValidatingFloat(50.5), false}, + {"invalid custom validating float - negative", CustomValidatingFloat(-5.5), true}, + {"invalid custom validating float - too large", CustomValidatingFloat(2000.5), true}, + {"valid custom validating float - boundary", CustomValidatingFloat(1000), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/serialization/validation_string_ptr_test.go b/internal/serialization/validation_string_ptr_test.go new file mode 100644 index 00000000..1de98ecc --- /dev/null +++ b/internal/serialization/validation_string_ptr_test.go @@ -0,0 +1,93 @@ +package serialization + +import ( + "reflect" + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +type CustomValidatingPointerString string + +func (c *CustomValidatingPointerString) Validate() gperr.Error { + if c == nil { + return gperr.New("pointer string cannot be nil") + } + if *c == "" { + return gperr.New("string cannot be empty") + } + if len(*c) < 2 { + return gperr.New("string must be at least 2 characters") + } + return nil +} + +func TestValidateWithCustomValidator_StringPointer(t *testing.T) { + tests := []struct { + name string + input *string + wantErr bool + }{ + {"valid string pointer", ptr("hello"), false}, + {"nil string pointer", nil, false}, + {"empty string pointer", ptr(""), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingPointerStringValue(t *testing.T) { + tests := []struct { + name string + input CustomValidatingPointerString + wantErr bool + }{ + {"custom validating pointer string as value - valid", CustomValidatingPointerString("hello"), false}, + {"custom validating pointer string as value - empty", CustomValidatingPointerString(""), false}, + {"custom validating pointer string as value - short", CustomValidatingPointerString("a"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingPointerStringPointer(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingPointerString + wantErr bool + }{ + {"valid custom validating pointer string", customStringPointerPtr(CustomValidatingPointerString("hello")), false}, + {"nil custom validating pointer string", nil, true}, // Should fail because Validate() checks for nil + {"invalid custom validating pointer string - empty", customStringPointerPtr(CustomValidatingPointerString("")), true}, + {"invalid custom validating pointer string - too short", customStringPointerPtr(CustomValidatingPointerString("a")), true}, + {"valid custom validating pointer string - minimum length", customStringPointerPtr(CustomValidatingPointerString("ab")), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// Helper function to create CustomValidatingPointerString pointer +func customStringPointerPtr(s CustomValidatingPointerString) *CustomValidatingPointerString { + return &s +} diff --git a/internal/serialization/validation_string_test.go b/internal/serialization/validation_string_test.go new file mode 100644 index 00000000..b432492b --- /dev/null +++ b/internal/serialization/validation_string_test.go @@ -0,0 +1,85 @@ +package serialization + +import ( + "reflect" + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +type CustomValidatingString string + +func (c CustomValidatingString) Validate() gperr.Error { + if c == "" { + return gperr.New("string cannot be empty") + } + if len(c) < 2 { + return gperr.New("string must be at least 2 characters") + } + return nil +} + +func TestValidateWithCustomValidator_String(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"regular string - no custom validation", "hello", false}, + {"empty regular string - no custom validation", "", false}, + {"short regular string - no custom validation", "a", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingString(t *testing.T) { + tests := []struct { + name string + input CustomValidatingString + wantErr bool + }{ + {"valid custom validating string", CustomValidatingString("hello"), false}, + {"invalid custom validating string - empty", CustomValidatingString(""), true}, + {"invalid custom validating string - too short", CustomValidatingString("a"), true}, + {"valid custom validating string - minimum length", CustomValidatingString("ab"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingStringPointer(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingString + wantErr bool + }{ + {"valid custom validating string pointer", ptr(CustomValidatingString("hello")), false}, + {"nil custom validating string pointer", nil, true}, + {"invalid custom validating string pointer - empty", ptr(CustomValidatingString("")), true}, + {"invalid custom validating string pointer - too short", ptr(CustomValidatingString("a")), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/serialization/validation_struct_ptr_test.go b/internal/serialization/validation_struct_ptr_test.go new file mode 100644 index 00000000..6d09646b --- /dev/null +++ b/internal/serialization/validation_struct_ptr_test.go @@ -0,0 +1,69 @@ +package serialization + +import ( + "reflect" + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +type CustomValidatingPointerStruct struct { + Value string +} + +func (c *CustomValidatingPointerStruct) Validate() gperr.Error { + if c == nil { + return gperr.New("pointer struct cannot be nil") + } + if c.Value == "" { + return gperr.New("value cannot be empty") + } + if len(c.Value) < 3 { + return gperr.New("value must be at least 3 characters") + } + return nil +} + +func TestValidateWithCustomValidator_CustomValidatingPointerStructValue(t *testing.T) { + tests := []struct { + name string + input CustomValidatingPointerStruct + wantErr bool + }{ + {"custom validating pointer struct as value - valid", CustomValidatingPointerStruct{Value: "hello"}, false}, + {"custom validating pointer struct as value - empty", CustomValidatingPointerStruct{Value: ""}, false}, + {"custom validating pointer struct as value - short", CustomValidatingPointerStruct{Value: "hi"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingPointerStructPointer(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingPointerStruct + wantErr bool + }{ + {"valid custom validating pointer struct", &CustomValidatingPointerStruct{Value: "hello"}, false}, + {"nil custom validating pointer struct", nil, true}, // Should fail because Validate() checks for nil + {"invalid custom validating pointer struct - empty", &CustomValidatingPointerStruct{Value: ""}, true}, + {"invalid custom validating pointer struct - too short", &CustomValidatingPointerStruct{Value: "hi"}, true}, + {"valid custom validating pointer struct - minimum length", &CustomValidatingPointerStruct{Value: "abc"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/serialization/validation_struct_test.go b/internal/serialization/validation_struct_test.go new file mode 100644 index 00000000..37fef26e --- /dev/null +++ b/internal/serialization/validation_struct_test.go @@ -0,0 +1,66 @@ +package serialization + +import ( + "reflect" + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +type CustomValidatingStruct struct { + Value string +} + +func (c CustomValidatingStruct) Validate() gperr.Error { + if c.Value == "" { + return gperr.New("value cannot be empty") + } + if len(c.Value) < 3 { + return gperr.New("value must be at least 3 characters") + } + return nil +} + +func TestValidateWithCustomValidator_Struct(t *testing.T) { + tests := []struct { + name string + input CustomValidatingStruct + wantErr bool + }{ + {"valid custom validating struct", CustomValidatingStruct{Value: "hello"}, false}, + {"invalid custom validating struct - empty", CustomValidatingStruct{Value: ""}, true}, + {"invalid custom validating struct - too short", CustomValidatingStruct{Value: "hi"}, true}, + {"valid custom validating struct - minimum length", CustomValidatingStruct{Value: "abc"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateWithCustomValidator_CustomValidatingStructPointer(t *testing.T) { + tests := []struct { + name string + input *CustomValidatingStruct + wantErr bool + }{ + {"valid custom validating struct pointer", &CustomValidatingStruct{Value: "hello"}, false}, + {"nil custom validating struct pointer", nil, true}, + {"invalid custom validating struct pointer - empty", &CustomValidatingStruct{Value: ""}, true}, + {"invalid custom validating struct pointer - too short", &CustomValidatingStruct{Value: "hi"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWithCustomValidator(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateWithCustomValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/serialization/validation_test.go b/internal/serialization/validation_test.go new file mode 100644 index 00000000..957a87f2 --- /dev/null +++ b/internal/serialization/validation_test.go @@ -0,0 +1 @@ +package serialization