improved deserialization method

This commit is contained in:
yusing
2024-12-18 07:18:18 +08:00
parent 6aefe4d5d9
commit f2a9ddd1a6
8 changed files with 310 additions and 98 deletions

View File

@@ -1,10 +1,13 @@
package types
import "github.com/yusing/go-proxy/internal/net/http/accesslog"
import (
"github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Config struct {
AutoCert *AutoCertConfig `json:"autocert"`
AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"`
Entrypoint Entrypoint `json:"entrypoint"`
Providers Providers `json:"providers"`
MatchDomains []string `json:"match_domains" validate:"dive,fqdn"`
@@ -18,7 +21,7 @@ type (
}
Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"`
AccessLog *accesslog.Config `json:"access_log"`
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
}
NotificationConfig map[string]any
)
@@ -31,3 +34,7 @@ func DefaultConfig() *Config {
},
}
}
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
}

View File

@@ -59,9 +59,9 @@ func fmtLog(cfg *Config) string {
}
func TestAccessLoggerCommon(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Format = FormatCommon
ExpectEqual(t, fmtLog(&config),
ExpectEqual(t, fmtLog(config),
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, TestTimeNow, method, uri, proto, status, contentLength,
),
@@ -69,9 +69,9 @@ func TestAccessLoggerCommon(t *testing.T) {
}
func TestAccessLoggerCombined(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Format = FormatCombined
ExpectEqual(t, fmtLog(&config),
ExpectEqual(t, fmtLog(config),
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
host, remote, TestTimeNow, method, uri, proto, status, contentLength, referer, ua,
),
@@ -79,10 +79,10 @@ func TestAccessLoggerCombined(t *testing.T) {
}
func TestAccessLoggerRedactQuery(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Format = FormatCommon
config.Fields.Query.DefaultMode = FieldModeRedact
ExpectEqual(t, fmtLog(&config),
ExpectEqual(t, fmtLog(config),
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, TestTimeNow, method, uriRedacted, proto, status, contentLength,
),
@@ -99,8 +99,8 @@ func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
}
func TestAccessLoggerJSON(t *testing.T) {
config := DefaultConfig
entry := getJSONEntry(t, &config)
config := DefaultConfig()
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.IP, remote)
ExpectEqual(t, entry.Method, method)
ExpectEqual(t, entry.Scheme, "http")

View File

@@ -1,5 +1,7 @@
package accesslog
import "github.com/yusing/go-proxy/internal/utils"
type (
Format string
Filters struct {
@@ -30,18 +32,24 @@ var (
const DefaultBufferSize = 100
var DefaultConfig = Config{
BufferSize: DefaultBufferSize,
Format: FormatCombined,
Fields: Fields{
Headers: FieldConfig{
DefaultMode: FieldModeDrop,
func DefaultConfig() *Config {
return &Config{
BufferSize: DefaultBufferSize,
Format: FormatCombined,
Fields: Fields{
Headers: FieldConfig{
DefaultMode: FieldModeDrop,
},
Query: FieldConfig{
DefaultMode: FieldModeKeep,
},
Cookies: FieldConfig{
DefaultMode: FieldModeDrop,
},
},
Query: FieldConfig{
DefaultMode: FieldModeKeep,
},
Cookies: FieldConfig{
DefaultMode: FieldModeDrop,
},
},
}
}
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
}

View File

@@ -10,9 +10,9 @@ import (
// Cookie header should be removed,
// stored in JSONLogEntry.Cookies instead.
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Headers.DefaultMode = FieldModeKeep
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0)
for k, v := range req.Header {
if k != "Cookie" {
@@ -22,9 +22,9 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Headers.DefaultMode = FieldModeRedact
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
@@ -34,10 +34,10 @@ func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
}
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Headers.DefaultMode = FieldModeKeep
config.Fields.Cookies.DefaultMode = FieldModeKeep
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
@@ -45,10 +45,10 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) {
}
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Headers.DefaultMode = FieldModeKeep
config.Fields.Cookies.DefaultMode = FieldModeRedact
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
@@ -56,17 +56,17 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) {
}
func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Query.DefaultMode = FieldModeDrop
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], nil)
ExpectDeepEqual(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig
config := DefaultConfig()
config.Fields.Query.DefaultMode = FieldModeRedact
entry := getJSONEntry(t, &config)
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Query["bar"], []string{RedactedValue})
}

View File

@@ -18,9 +18,9 @@ type Webhook struct {
Template string `json:"template" validate:"omitempty,oneof=discord"`
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"`
Tok string `json:"token"`
Meth string `json:"method" validate:"omitempty,oneof=GET POST PUT"`
Meth string `json:"method" validate:"oneof=GET POST PUT"`
MIMETyp string `json:"mime_type"`
ColorM string `json:"color_mode" validate:"omitempty,oneof=hex dec"`
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
}
//go:embed templates/discord.json
@@ -30,6 +30,14 @@ var webhookTemplates = map[string]string{
"discord": discordPayload,
}
func DefaultValue() *Webhook {
return &Webhook{
Meth: "POST",
ColorM: "hex",
MIMETyp: "application/json",
}
}
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
template := fl.Parent().FieldByName("Template").String()
if template != "" {
@@ -40,6 +48,7 @@ func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
}
func init() {
utils.RegisterDefaultValueFactory(DefaultValue)
err := utils.Validator().RegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
if err != nil {
panic(err)
@@ -53,10 +62,7 @@ func (webhook *Webhook) Name() string {
// Method implements Provider.
func (webhook *Webhook) Method() string {
if webhook.Meth != "" {
return webhook.Meth
}
return http.MethodPost
return webhook.Meth
}
// URL implements Provider.
@@ -71,10 +77,7 @@ func (webhook *Webhook) Token() string {
// MIMEType implements Provider.
func (webhook *Webhook) MIMEType() string {
if webhook.MIMETyp != "" {
return webhook.MIMETyp
}
return "application/json"
return webhook.MIMETyp
}
func (webhook *Webhook) ColorMode() string {

View File

@@ -13,6 +13,7 @@ import (
"github.com/go-playground/validator/v10"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
"gopkg.in/yaml.v3"
@@ -30,6 +31,28 @@ var (
ErrUnknownField = E.New("unknown field")
)
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
t := reflect.TypeFor[T]()
if t.Kind() == reflect.Ptr {
panic("pointer of pointer")
}
if defaultValues.Has(t) {
panic("default value for " + t.String() + " already registered")
}
defaultValues.Store(t, func() any { return factory() })
}
func New(t reflect.Type) reflect.Value {
if dv, ok := defaultValues.Load(t); ok {
logging.Debug().Str("type", t.String()).Msg("using default value")
return reflect.ValueOf(dv())
}
logging.Debug().Str("type", t.String()).Msg("using zero value")
return reflect.New(t)
}
// Serialize converts the given data into a map[string]any representation.
//
// It uses reflection to inspect the data type and handle different kinds of data.
@@ -150,7 +173,7 @@ func Deserialize(src SerializedObject, dst any) E.Error {
for dstT.Kind() == reflect.Ptr {
if dstV.IsNil() {
if dstV.CanSet() {
dstV.Set(reflect.New(dstT.Elem()))
dstV.Set(New(dstT.Elem()))
} else {
return E.Errorf("deserialize: dst is %w", ErrNilValue)
}
@@ -214,12 +237,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
if e.Param() != "" {
detail += ":" + e.Param()
}
fieldName, ok := fieldName[e.Field()]
if !ok {
fieldName = e.Field()
}
errs.Add(ErrValidationError.
Subject(fieldName).
Subject(e.StructNamespace()).
Withf("require %q", detail))
}
}
@@ -230,12 +249,14 @@ func Deserialize(src SerializedObject, dst any) E.Error {
dstV.Set(reflect.MakeMap(dstT))
}
for k := range src {
tmp := reflect.New(dstT.Elem()).Elem()
mapVT := dstT.Elem()
tmp := New(mapVT).Elem()
err := Convert(reflect.ValueOf(src[k]), tmp)
if err != nil {
if err == nil {
dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
} else {
errs.Add(err.Subject(k))
}
dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
}
return errs.Error()
default:
@@ -243,6 +264,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
}
}
func isIntFloat(t reflect.Kind) bool {
return t >= reflect.Bool && t <= reflect.Float64
}
// Convert attempts to convert the src to dst.
//
// If src is a map, it is deserialized into dst.
@@ -270,20 +295,41 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(reflect.New(dstT.Elem()))
dst.Set(New(dstT.Elem()))
}
dst = dst.Elem()
dstT = dst.Type()
}
srcKind := srcT.Kind()
switch {
case srcT.AssignableTo(dstT):
dst.Set(src)
return nil
case srcT.ConvertibleTo(dstT):
dst.Set(src.Convert(dstT))
return nil
case srcT.Kind() == reflect.Map:
// case srcT.ConvertibleTo(dstT):
// dst.Set(src.Convert(dstT))
// return nil
case srcKind == reflect.String:
if convertible, err := ConvertString(src.String(), dst); convertible {
return err
}
case isIntFloat(srcKind):
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
}
if convertible, err := ConvertString(strV, dst); convertible {
return err
}
case srcKind == reflect.Map:
if src.Len() == 0 {
return nil
}
@@ -292,7 +338,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
}
return Deserialize(obj, dst.Addr().Interface())
case srcT.Kind() == reflect.Slice:
case srcKind == reflect.Slice:
if src.Len() == 0 {
return nil
}
@@ -302,7 +348,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
newSlice := reflect.MakeSlice(dstT, 0, src.Len())
i := 0
for _, v := range src.Seq2() {
tmp := reflect.New(dstT.Elem()).Elem()
tmp := New(dstT.Elem()).Elem()
err := Convert(v, tmp)
if err != nil {
return err.Subjectf("[%d]", i)
@@ -312,24 +358,16 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
}
dst.Set(newSlice)
return nil
case src.Kind() == reflect.String:
if convertible, err := ConvertString(src.String(), dst); convertible {
return err
}
}
// check if (*T).Convertor is implemented
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
return E.From(parser.Parse(src.String()))
}
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
convertible = true
dstT := dst.Type()
if dst.Kind() == reflect.Ptr {
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
dst.Set(New(dstT.Elem()))
}
dst = dst.Elem()
}
@@ -337,10 +375,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
dst.SetString(src)
return
}
switch dst.Type() {
switch dstT {
case reflect.TypeFor[time.Duration]():
if src == "" {
dst.Set(reflect.Zero(dst.Type()))
dst.Set(reflect.Zero(dstT))
return
}
d, err := time.ParseDuration(src)
@@ -357,34 +395,33 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(*ipnet))
dst.Set(reflect.ValueOf(ipnet).Elem())
return
default:
}
// primitive types / simple types
switch dst.Kind() {
case reflect.Bool:
b, err := strconv.ParseBool(src)
if dstKind := dst.Kind(); isIntFloat(dstKind) {
var i any
var err error
switch {
case dstKind == reflect.Bool:
i, err = strconv.ParseBool(src)
case dst.CanInt():
i, err = strconv.ParseInt(src, 10, dstT.Bits())
case dst.CanUint():
i, err = strconv.ParseUint(src, 10, dstT.Bits())
case dst.CanFloat():
i, err = strconv.ParseFloat(src, dstT.Bits())
}
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(b))
return
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := strconv.ParseInt(src, 10, 64)
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(i).Convert(dst.Type()))
return
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i, err := strconv.ParseUint(src, 10, 64)
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(i).Convert(dst.Type()))
dst.Set(reflect.ValueOf(i).Convert(dstT))
return
}
// check if (*T).Convertor is implemented
if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
return true, E.From(parser.Parse(src))
}
// yaml like
lines := []string{}
src = strings.TrimSpace(src)
@@ -446,10 +483,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
}
tmp = m
}
if tmp == nil {
return false, nil
if tmp != nil {
return true, Convert(reflect.ValueOf(tmp), dst)
}
return true, Convert(reflect.ValueOf(tmp), dst)
return false, nil
}
func DeserializeYAML[T any](data []byte, target T) E.Error {

View File

@@ -124,6 +124,7 @@ func TestStringIntConvert(t *testing.T) {
type testModel struct {
Test testType
Baz string
}
type testType struct {
@@ -146,8 +147,19 @@ func TestConvertor(t *testing.T) {
ExpectEqual(t, m.Test.bar, "123")
})
t.Run("int_to_string", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, Deserialize(map[string]any{"Test": "123"}, m))
ExpectEqual(t, m.Test.foo, 123)
ExpectEqual(t, m.Test.bar, "123")
ExpectNoError(t, Deserialize(map[string]any{"Baz": 123}, m))
ExpectEqual(t, m.Baz, "123")
})
t.Run("invalid", func(t *testing.T) {
m := new(testModel)
ExpectError(t, strconv.ErrSyntax, Deserialize(map[string]any{"Test": 123}, m))
ExpectError(t, ErrUnsupportedConversion, Deserialize(map[string]any{"Test": struct{}{}}, m))
})
}