diff --git a/internal/serialization/serialization.go b/internal/serialization/serialization.go index 57eb4771..1a4f153e 100644 --- a/internal/serialization/serialization.go +++ b/internal/serialization/serialization.go @@ -89,31 +89,35 @@ func ValidateWithFieldTags(s any) gperr.Error { return errs.Error() } +var validatorType = reflect.TypeFor[CustomValidator]() + func ValidateWithCustomValidator(v reflect.Value) gperr.Error { - isStruct := false - for { - switch v.Kind() { - case reflect.Pointer, reflect.Interface: - if v.IsNil() { - return gperr.Errorf("validate: v is %w", ErrNilValue) - } - if validate, ok := v.Interface().(CustomValidator); ok { - return validate.Validate() - } - if isStruct { - return nil - } - v = v.Elem() - case reflect.Struct: - if !v.CanAddr() { - return nil - } - v = v.Addr() - isStruct = true - default: + if v.Kind() == reflect.Struct { + if v.Type().Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + if v.CanAddr() { + return validateWithValidator(v.Addr()) + } + return nil + } + if v.Kind() == reflect.Pointer { + if v.IsNil() { return nil } + if v.Type().Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + return validateWithValidator(v.Elem()) } + return nil +} + +func validateWithValidator(v reflect.Value) gperr.Error { + if v.Type().Implements(validatorType) { + return v.Interface().(CustomValidator).Validate() + } + return nil } func dive(dst reflect.Value) (v reflect.Value, t reflect.Type, err gperr.Error) {