refactor: simplify io code and make utils module independent

This commit is contained in:
yusing
2025-05-23 22:19:14 +08:00
parent ff08c40403
commit f1e204f7fd
24 changed files with 124 additions and 73 deletions

21
internal/utils/go.mod Normal file
View File

@@ -0,0 +1,21 @@
module github.com/yusing/go-proxy/internal/utils
go 1.24.3
require (
github.com/goccy/go-yaml v1.17.1
github.com/puzpuzpuz/xsync/v4 v4.1.0
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
go.uber.org/atomic v1.11.0
golang.org/x/text v0.25.0
)
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
golang.org/x/sys v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
internal/utils/go.sum Normal file
View File

@@ -0,0 +1,36 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U=
github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -8,7 +8,6 @@ import (
"sync"
"syscall"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/synk"
)
@@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
}
}
func (p BidirectionalPipe) Start() gperr.Error {
func (p BidirectionalPipe) Start() error {
var wg sync.WaitGroup
wg.Add(2)
b := gperr.NewBuilder("bidirectional pipe error")
var srcErr, dstErr error
go func() {
b.Add(p.pSrcDst.Start())
srcErr = p.pSrcDst.Start()
wg.Done()
}()
go func() {
b.Add(p.pDstSrc.Start())
dstErr = p.pDstSrc.Start()
wg.Done()
}()
wg.Wait()
return b.Error()
return errors.Join(srcErr, dstErr)
}
type httpFlusher interface {
@@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
wCloser, wCanClose := dst.Writer.(io.Closer)
rCloser, rCanClose := src.Reader.(io.Closer)
if wCanClose || rCanClose {
if src.ctx == dst.ctx {
go func() {
<-src.ctx.Done()
if wCanClose {
wCloser.Close()
}
if rCanClose {
rCloser.Close()
}
}()
} else {
if wCloser != nil {
go func() {
<-src.ctx.Done()
wCloser.Close()
}()
go func() {
select {
case <-src.ctx.Done():
case <-dst.ctx.Done():
}
if rCloser != nil {
go func() {
<-dst.ctx.Done()
rCloser.Close()
}()
if rCanClose {
defer rCloser.Close()
}
}
if wCanClose {
defer wCloser.Close()
}
}()
}
flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil

View File

@@ -4,7 +4,7 @@ import (
"sort"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
type (
@@ -29,12 +29,12 @@ func (p Pool[T]) Name() string {
func (p Pool[T]) Add(obj T) {
p.checkExists(obj.Key())
p.m.Store(obj.Key(), obj)
logging.Info().Msgf("%s: added %s", p.name, obj.Name())
log.Info().Msgf("%s: added %s", p.name, obj.Name())
}
func (p Pool[T]) Del(obj T) {
p.m.Delete(obj.Key())
logging.Info().Msgf("%s: removed %s", p.name, obj.Name())
log.Info().Msgf("%s: removed %s", p.name, obj.Name())
}
func (p Pool[T]) Get(key string) (T, bool) {

View File

@@ -5,11 +5,11 @@ package pool
import (
"runtime/debug"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
func (p Pool[T]) checkExists(key string) {
if _, ok := p.m.Load(key); ok {
logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
}
}

View File

@@ -1,564 +0,0 @@
package utils
import (
"encoding/json"
"errors"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/goccy/go-yaml"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type SerializedObject = map[string]any
type MapUnmarshaller interface {
UnmarshalMap(m map[string]any) gperr.Error
}
var (
ErrInvalidType = gperr.New("invalid type")
ErrNilValue = gperr.New("nil")
ErrUnsettable = gperr.New("unsettable")
ErrUnsupportedConversion = gperr.New("unsupported conversion")
ErrUnknownField = gperr.New("unknown field")
)
var (
tagDeserialize = "deserialize" // `deserialize:"-"` to exclude from deserialization
tagJSON = "json" // share between Deserialize and json.Marshal
tagValidate = "validate" // uses go-playground/validator
tagAliases = "aliases" // declare aliases for fields
)
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
t := reflect.TypeFor[T]()
if t.Kind() == reflect.Ptr {
panic("pointer of pointer")
}
if defaultValues.Has(t) {
panic("default value for " + t.String() + " already registered")
}
defaultValues.Store(t, func() any { return factory() })
}
func New(t reflect.Type) reflect.Value {
if dv, ok := defaultValues.Load(t); ok {
return reflect.ValueOf(dv())
}
return reflect.New(t)
}
func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, nil
}
n := t.NumField()
fields := make([]reflect.StructField, 0, n)
for i := range n {
field := t.Field(i)
if !field.IsExported() {
continue
}
if field.Tag.Get(tagDeserialize) == "-" {
continue
}
if field.Anonymous {
f1, f2 := extractFields(field.Type)
fields = append(fields, f1...)
anonymous = append(anonymous, field)
anonymous = append(anonymous, f2...)
} else {
fields = append(fields, field)
}
}
return fields, anonymous
}
func ValidateWithFieldTags(s any) gperr.Error {
errs := gperr.NewBuilder()
err := validate.Struct(s)
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
if detail != "required" {
detail = "require " + strconv.Quote(detail)
}
errs.Add(ErrValidationError.
Subject(e.Namespace()).
Withf(detail))
}
}
return errs.Error()
}
func ValidateWithCustomValidator(v reflect.Value) gperr.Error {
isStruct := false
for {
switch v.Kind() {
case reflect.Pointer, reflect.Interface:
if v.IsNil() {
return gperr.Errorf("validate: v is %w", ErrNilValue)
}
if validate, ok := v.Interface().(CustomValidator); ok {
return validate.Validate()
}
if isStruct {
return nil
}
v = v.Elem()
case reflect.Struct:
if !v.CanAddr() {
return nil
}
v = v.Addr()
isStruct = true
default:
return nil
}
}
}
func dive(dst reflect.Value) (v reflect.Value, t reflect.Type, err gperr.Error) {
dstT := dst.Type()
for {
switch dst.Kind() {
case reflect.Pointer, reflect.Interface:
if dst.IsNil() {
if !dst.CanSet() {
err = gperr.Errorf("dive: dst is %w and is not settable", ErrNilValue)
return
}
dst.Set(New(dstT.Elem()))
}
dst = dst.Elem()
dstT = dst.Type()
case reflect.Struct:
return dst, dstT, nil
default:
if dst.IsNil() {
switch dst.Kind() {
case reflect.Map:
dst.Set(reflect.MakeMap(dstT))
case reflect.Slice:
dst.Set(reflect.MakeSlice(dstT, 0, 0))
default:
err = gperr.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
return
}
}
return dst, dstT, nil
}
}
}
// MapUnmarshalValidate takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
// MapUnmarshalValidate ignores case differences between the field names in the SerializedObject and the target.
//
// The target value must be a struct or a map[string]any.
// If the target value is a struct , and implements the MapUnmarshaller interface,
// the UnmarshalMap method will be called.
//
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
// the SerializedObject will be deserialized into the struct fields and validate if needed.
//
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func MapUnmarshalValidate(src SerializedObject, dst any) (err gperr.Error) {
return mapUnmarshalValidate(src, dst, true)
}
func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool) (err gperr.Error) {
dstV := reflect.ValueOf(dst)
dstT := dstV.Type()
if src == nil {
if dstV.CanSet() {
dstV.Set(reflect.Zero(dstT))
return nil
}
return gperr.Errorf("deserialize: src is %w and dst is not settable\n%s", ErrNilValue, debug.Stack())
}
if dstT.Implements(mapUnmarshalerType) {
dstV, _, err = dive(dstV)
if err != nil {
return err
}
return dstV.Addr().Interface().(MapUnmarshaller).UnmarshalMap(src)
}
dstV, dstT, err = dive(dstV)
if err != nil {
return err
}
// convert data fields to lower no-snake
// convert target fields to lower no-snake
// then check if the field of data is in the target
errs := gperr.NewBuilder()
switch dstV.Kind() {
case reflect.Struct, reflect.Interface:
hasValidateTag := false
mapping := make(map[string]reflect.Value)
fields, anonymous := extractFields(dstT)
for _, anon := range anonymous {
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(New(anon.Type.Elem()))
}
}
for _, field := range fields {
var key string
if jsonTag, ok := field.Tag.Lookup(tagJSON); ok {
if jsonTag == "-" {
continue
}
key = strutils.CommaSeperatedList(jsonTag)[0]
} else {
key = field.Name
}
key = strutils.ToLowerNoSnake(key)
mapping[key] = dstV.FieldByName(field.Name)
if !hasValidateTag {
_, hasValidateTag = field.Tag.Lookup(tagValidate)
}
aliases, ok := field.Tag.Lookup(tagAliases)
if ok {
for _, alias := range strutils.CommaSeperatedList(aliases) {
mapping[alias] = dstV.FieldByName(field.Name)
}
}
}
for k, v := range src {
if field, ok := mapping[strutils.ToLowerNoSnake(k)]; ok {
err := Convert(reflect.ValueOf(v), field, !hasValidateTag)
if err != nil {
errs.Add(err.Subject(k))
}
} else {
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
}
}
if hasValidateTag && checkValidateTag {
errs.Add(ValidateWithFieldTags(dstV.Interface()))
}
if err := ValidateWithCustomValidator(dstV); err != nil {
errs.Add(err)
}
return errs.Error()
case reflect.Map:
for k, v := range src {
mapVT := dstT.Elem()
tmp := New(mapVT).Elem()
err := Convert(reflect.ValueOf(v), tmp, true)
if err != nil {
errs.Add(err.Subject(k))
continue
}
if err := ValidateWithCustomValidator(tmp.Addr()); err != nil {
errs.Add(err.Subject(k))
} else {
dstV.SetMapIndex(reflect.ValueOf(k), tmp)
}
}
if err := ValidateWithCustomValidator(dstV); err != nil {
errs.Add(err)
}
return errs.Error()
default:
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String() + " ")
}
}
func isIntFloat(t reflect.Kind) bool {
return t >= reflect.Bool && t <= reflect.Float64
}
// Convert attempts to convert the src to dst.
//
// If src is a map, it is deserialized into dst.
// If src is a slice, each of its elements are converted and stored in dst.
// For any other type, it is converted using the reflect.Value.Convert function (if possible).
//
// If dst is not settable, an error is returned.
// If src cannot be converted to dst, an error is returned.
// If any error occurs during conversion (e.g. deserialization), it is returned.
//
// Returns:
// - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.Error {
if !dst.IsValid() {
return gperr.Errorf("convert: dst is %w", ErrNilValue)
}
if (src.Kind() == reflect.Pointer && src.IsNil()) || !src.IsValid() {
if !dst.CanSet() {
return gperr.Errorf("convert: src is %w", ErrNilValue)
}
// manually set nil
dst.Set(reflect.Zero(dst.Type()))
return nil
}
if src.IsZero() {
if !dst.CanSet() {
return gperr.Errorf("convert: src is %w", ErrNilValue)
}
switch dst.Kind() {
case reflect.Pointer, reflect.Interface:
dst.Set(reflect.New(dst.Type().Elem()))
default:
dst.Set(reflect.Zero(dst.Type()))
}
return nil
}
srcT := src.Type()
dstT := dst.Type()
if src.Kind() == reflect.Interface {
src = src.Elem()
srcT = src.Type()
}
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(New(dstT.Elem()))
}
dst = dst.Elem()
dstT = dst.Type()
}
srcKind := srcT.Kind()
switch {
case srcT.AssignableTo(dstT):
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
dst.Set(src)
return nil
// case srcT.ConvertibleTo(dstT):
// dst.Set(src.Convert(dstT))
// return nil
case srcKind == reflect.String:
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
if convertible, err := ConvertString(src.String(), dst); convertible {
return err
}
case isIntFloat(srcKind):
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
}
if convertible, err := ConvertString(strV, dst); convertible {
return err
}
case srcKind == reflect.Map:
if src.Len() == 0 {
return nil
}
obj, ok := src.Interface().(SerializedObject)
if !ok {
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
}
return mapUnmarshalValidate(obj, dst.Addr().Interface(), checkValidateTag)
case srcKind == reflect.Slice:
if src.Len() == 0 {
return nil
}
if dstT.Kind() != reflect.Slice {
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
}
sliceErrs := gperr.NewBuilder()
newSlice := reflect.MakeSlice(dstT, src.Len(), src.Len())
i := 0
for j, v := range src.Seq2() {
tmp := New(dstT.Elem()).Elem()
err := Convert(v, tmp, checkValidateTag)
if err != nil {
sliceErrs.Add(err.Subjectf("[%d]", j))
continue
}
newSlice.Index(i).Set(tmp)
i++
}
if err := sliceErrs.Error(); err != nil {
return err
}
dst.Set(newSlice)
return nil
}
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
convertible = true
dstT := dst.Type()
if dst.Kind() == reflect.Ptr {
if dst.IsNil() {
dst.Set(New(dstT.Elem()))
}
dst = dst.Elem()
dstT = dst.Type()
}
if dst.Kind() == reflect.String {
dst.SetString(src)
return
}
switch dstT {
case reflect.TypeFor[time.Duration]():
if src == "" {
dst.Set(reflect.Zero(dstT))
return
}
d, err := time.ParseDuration(src)
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(d))
return
default:
}
if dstKind := dst.Kind(); isIntFloat(dstKind) {
var i any
var err error
switch {
case dstKind == reflect.Bool:
i, err = strconv.ParseBool(src)
case dst.CanInt():
i, err = strconv.ParseInt(src, 10, dstT.Bits())
case dst.CanUint():
i, err = strconv.ParseUint(src, 10, dstT.Bits())
case dst.CanFloat():
i, err = strconv.ParseFloat(src, dstT.Bits())
}
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(i).Convert(dstT))
return
}
// check if (*T).Convertor is implemented
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
return true, gperr.Wrap(parser.Parse(src))
}
// yaml like
var tmp any
switch dst.Kind() {
case reflect.Slice:
src = strings.TrimSpace(src)
isMultiline := strings.ContainsRune(src, '\n')
// one liner is comma separated list
if !isMultiline && src[0] != '-' {
values := strutils.CommaSeperatedList(src)
dst.Set(reflect.MakeSlice(dst.Type(), len(values), len(values)))
errs := gperr.NewBuilder()
for i, v := range values {
err := Convert(reflect.ValueOf(v), dst.Index(i), true)
if err != nil {
errs.Add(err.Subjectf("[%d]", i))
}
}
if errs.HasError() {
return true, errs.Error()
}
return
}
sl := make([]any, 0)
err := yaml.Unmarshal([]byte(src), &sl)
if err != nil {
return true, gperr.Wrap(err)
}
tmp = sl
case reflect.Map, reflect.Struct:
rawMap := make(SerializedObject)
err := yaml.Unmarshal([]byte(src), &rawMap)
if err != nil {
return true, gperr.Wrap(err)
}
tmp = rawMap
default:
return false, nil
}
return true, Convert(reflect.ValueOf(tmp), dst, true)
}
func UnmarshalValidateYAML[T any](data []byte, target *T) gperr.Error {
m := make(map[string]any)
if err := yaml.Unmarshal(data, &m); err != nil {
return gperr.Wrap(err)
}
return MapUnmarshalValidate(m, target)
}
func UnmarshalValidateYAMLXSync[V any](data []byte) (_ functional.Map[string, V], err gperr.Error) {
m := make(map[string]any)
if err = gperr.Wrap(yaml.Unmarshal(data, &m)); err != nil {
return
}
m2 := make(map[string]V, len(m))
if err = MapUnmarshalValidate(m, m2); err != nil {
return
}
return functional.NewMapFrom(m2), nil
}
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return deserialize(data, dst)
}
func SaveJSON[T any](path string, src *T, perm os.FileMode) error {
data, err := json.Marshal(src)
if err != nil {
return err
}
return os.WriteFile(path, data, perm)
}
func LoadJSONIfExist[T any](path string, dst *T) error {
err := loadSerialized(path, dst, json.Unmarshal)
if os.IsNotExist(err) {
return nil
}
return err
}

View File

@@ -1,335 +0,0 @@
package utils
import (
"reflect"
"strconv"
"testing"
"github.com/goccy/go-yaml"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestDeserialize(t *testing.T) {
type S struct {
I int
S string
IS []int
SS []string
MSI map[string]int
MIS map[int]string
}
var (
testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
SS: []string{"a", "b", "c"},
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
"SS": []string{"a", "b", "c"},
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
)
t.Run("deserialize", func(t *testing.T) {
var s2 S
err := MapUnmarshalValidate(testStructSerialized, &s2)
ExpectNoError(t, err)
ExpectEqual(t, s2, testStruct)
})
}
func TestDeserializeAnonymousField(t *testing.T) {
type Anon struct {
A, B int
}
var s struct {
Anon
C int
}
var s2 struct {
*Anon
C int
}
// all, anon := extractFields(reflect.TypeOf(s2))
// t.Fatalf("anon %v, all %v", anon, all)
err := MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s)
ExpectNoError(t, err)
ExpectEqual(t, s.A, 1)
ExpectEqual(t, s.B, 2)
ExpectEqual(t, s.C, 3)
err = MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s2)
ExpectNoError(t, err)
ExpectEqual(t, s2.A, 1)
ExpectEqual(t, s2.B, 2)
ExpectEqual(t, s2.C, 3)
}
func TestPointerPrimitives(t *testing.T) {
type testType struct {
B *bool `json:"b"`
I8 *int8 `json:"i8"`
I16 *int16 `json:"i16"`
I32 *int32 `json:"i32"`
I64 *int64 `json:"i64"`
U8 *uint8 `json:"u8"`
U16 *uint16 `json:"u16"`
U32 *uint32 `json:"u32"`
U64 *uint64 `json:"u64"`
}
var test testType
err := MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
ExpectNoError(t, err)
ExpectEqual(t, *test.B, true)
ExpectEqual(t, *test.I8, int8(127))
ExpectEqual(t, *test.I16, int16(127))
ExpectEqual(t, *test.I32, int32(127))
ExpectEqual(t, *test.I64, int64(127))
ExpectEqual(t, *test.U8, uint8(127))
ExpectEqual(t, *test.U16, uint16(127))
ExpectEqual(t, *test.U32, uint32(127))
ExpectEqual(t, *test.U64, uint64(127))
// zero values
err = MapUnmarshalValidate(map[string]any{"b": false, "i8": int8(0), "i16": int16(0), "i32": int32(0), "i64": int64(0), "u8": uint8(0), "u16": uint16(0), "u32": uint32(0), "u64": uint64(0)}, &test)
ExpectNoError(t, err)
ExpectEqual(t, *test.B, false)
ExpectEqual(t, *test.I8, int8(0))
ExpectEqual(t, *test.I16, int16(0))
ExpectEqual(t, *test.I32, int32(0))
ExpectEqual(t, *test.I64, int64(0))
ExpectEqual(t, *test.U8, uint8(0))
ExpectEqual(t, *test.U16, uint16(0))
ExpectEqual(t, *test.U32, uint32(0))
ExpectEqual(t, *test.U64, uint64(0))
// nil values
err = MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
ExpectNoError(t, err)
err = MapUnmarshalValidate(map[string]any{"b": nil, "i8": nil, "i16": nil, "i32": nil, "i64": nil, "u8": nil, "u16": nil, "u32": nil, "u64": nil}, &test)
ExpectNoError(t, err)
ExpectEqual(t, test.B, nil)
ExpectEqual(t, test.I8, nil)
ExpectEqual(t, test.I16, nil)
ExpectEqual(t, test.I32, nil)
ExpectEqual(t, test.I64, nil)
ExpectEqual(t, test.U8, nil)
ExpectEqual(t, test.U16, nil)
ExpectEqual(t, test.U32, nil)
ExpectEqual(t, test.U64, nil)
}
func TestStringIntConvert(t *testing.T) {
s := "127"
test := struct {
i8 int8
i16 int16
i32 int32
i64 int64
u8 uint8
u16 uint16
u32 uint32
u64 uint64
}{}
ok, err := ConvertString(s, reflect.ValueOf(&test.i8))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.i8, int8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i16))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.i16, int16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i32))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.i32, int32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.i64, int64(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.u8, uint8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.u16, uint16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.u32, uint32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqual(t, test.u64, uint64(127))
}
type testModel struct {
Test testType
Baz string
}
type testType struct {
foo int
bar string
}
func (c *testType) Parse(v string) (err error) {
c.bar = v
c.foo, err = strconv.Atoi(v)
return
}
func TestConvertor(t *testing.T) {
t.Run("valid", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
ExpectEqual(t, m.Test.foo, 123)
ExpectEqual(t, m.Test.bar, "123")
})
t.Run("int_to_string", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
ExpectEqual(t, m.Test.foo, 123)
ExpectEqual(t, m.Test.bar, "123")
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m))
ExpectEqual(t, m.Baz, "123")
})
t.Run("invalid", func(t *testing.T) {
m := new(testModel)
err := MapUnmarshalValidate(map[string]any{"Test": struct{ a int }{1}}, m)
ExpectError(t, ErrUnsupportedConversion, err)
})
t.Run("set_empty", func(t *testing.T) {
m := testModel{
Test: testType{1, "2"},
Baz: "3",
}
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": nil, "Baz": nil}, &m))
ExpectEqual(t, m, testModel{})
})
}
func TestStringToSlice(t *testing.T) {
t.Run("comma_separated", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("a,b,c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqual(t, dst, []string{"a", "b", "c"})
})
t.Run("yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqual(t, dst, []string{"a", "b", "c"})
})
t.Run("single-line-yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqual(t, dst, []string{"a"})
})
}
func BenchmarkStringToSlice(b *testing.B) {
for range b.N {
dst := make([]int, 0)
_, _ = ConvertString("- 1\n- 2\n- 3", reflect.ValueOf(&dst))
}
}
func BenchmarkStringToSliceYAML(b *testing.B) {
for range b.N {
dst := make([]int, 0)
_ = yaml.Unmarshal([]byte("- 1\n- 2\n- 3"), &dst)
}
}
func TestStringToMap(t *testing.T) {
t.Run("yaml-like", func(t *testing.T) {
dst := make(map[string]string)
convertible, err := ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqual(t, dst, map[string]string{"a": "b", "c": "d"})
})
}
func BenchmarkStringToMap(b *testing.B) {
for range b.N {
dst := make(map[string]string)
_, _ = ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
}
}
func BenchmarkStringToMapYAML(b *testing.B) {
for range b.N {
dst := make(map[string]string)
_ = yaml.Unmarshal([]byte(" a: b\n c: d"), &dst)
}
}
func TestStringToStruct(t *testing.T) {
t.Run("yaml-like", func(t *testing.T) {
dst := struct {
A string
B int
}{}
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqual(t, dst, struct {
A string
B int
}{"a", 123})
})
}
func BenchmarkStringToStruct(b *testing.B) {
for range b.N {
dst := struct {
A string `json:"a"`
B int `json:"b"`
}{}
_, _ = ConvertString(" a: a\n b: 123", reflect.ValueOf(&dst))
}
}
func BenchmarkStringToStructYAML(b *testing.B) {
for range b.N {
dst := struct {
A string `yaml:"a"`
B int `yaml:"b"`
}{}
_ = yaml.Unmarshal([]byte(" a: a\n b: 123"), &dst)
}
}

View File

@@ -6,7 +6,7 @@ import (
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
type BytesPool struct {
@@ -140,6 +140,6 @@ func dropBuffers() {
checks++
}
if count > 0 {
logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
log.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
}
}

View File

@@ -2,14 +2,16 @@ package expect
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
var isTest = strings.HasSuffix(os.Args[0], ".test")
func init() {
if common.IsTest {
if isTest {
// force verbose output
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}

View File

@@ -1,19 +1,11 @@
package expect
import (
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
func init() {
if common.IsTest {
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
}
func ExpectNoError(t *testing.T, err error) {
t.Helper()
require.NoError(t, err)

View File

@@ -3,7 +3,6 @@ package utils
import (
"time"
"github.com/yusing/go-proxy/internal/task"
"go.uber.org/atomic"
)
@@ -38,8 +37,6 @@ func init() {
go func() {
for {
select {
case <-task.RootContext().Done():
return
case <-timeNowTicker.C:
shouldCallTimeNow.Store(true)
}

View File

@@ -1,25 +0,0 @@
package utils
import (
"github.com/go-playground/validator/v10"
"github.com/yusing/go-proxy/internal/gperr"
)
var validate = validator.New()
var ErrValidationError = gperr.New("validation error")
type CustomValidator interface {
Validate() gperr.Error
}
func Validator() *validator.Validate {
return validate
}
func MustRegisterValidation(tag string, fn validator.Func) {
err := validate.RegisterValidation(tag, fn)
if err != nil {
panic(err)
}
}