From 1586610a44c6249c98de7ad669bca6b2f1862e9b Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 26 Jan 2025 14:43:48 +0800 Subject: [PATCH] Cleaned up some validation code, stricter validation --- internal/api/v1/file.go | 46 ++++-- internal/autocert/config.go | 103 +++++++++----- internal/autocert/provider.go | 7 +- internal/autocert/user.go | 2 + internal/config/config.go | 18 +-- internal/config/types/autocert_config.go | 14 -- internal/config/types/config.go | 34 +++-- internal/error/utils.go | 7 + internal/homepage/homepage.go | 22 +-- internal/homepage/override_config.go | 4 - internal/notif/base.go | 47 +++++++ internal/notif/config.go | 54 +++++++ internal/notif/config_test.go | 163 ++++++++++++++++++++++ internal/notif/dispatcher.go | 36 +---- internal/notif/gotify.go | 27 +--- internal/notif/gotify_test.go | 52 ------- internal/notif/providers.go | 48 +++---- internal/notif/webhook.go | 124 ++++++++-------- internal/notif/webhook_test.go | 121 ---------------- internal/route/routes/routequery/query.go | 2 +- internal/utils/serialization.go | 115 +++++++++------ internal/utils/testing/testing.go | 8 ++ internal/utils/validation.go | 4 + 23 files changed, 590 insertions(+), 468 deletions(-) delete mode 100644 internal/config/types/autocert_config.go create mode 100644 internal/notif/base.go create mode 100644 internal/notif/config.go create mode 100644 internal/notif/config_test.go delete mode 100644 internal/notif/gotify_test.go delete mode 100644 internal/notif/webhook_test.go diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index 63f3aa3e..0a849772 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -75,6 +75,38 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) { U.WriteBody(w, content) } +func validateFile(fileType FileType, content []byte) error { + switch fileType { + case FileTypeConfig: + return config.Validate(content) + case FileTypeMiddleware: + errs := E.NewBuilder("middleware errors") + middleware.BuildMiddlewaresFromYAML("", content, errs) + return errs.Error() + } + return provider.Validate(content) +} + +func ValidateFile(w http.ResponseWriter, r *http.Request) { + fileType := FileType(r.PathValue("type")) + if !fileType.IsValid() { + U.RespondError(w, U.ErrInvalidKey("type"), http.StatusBadRequest) + return + } + content, err := io.ReadAll(r.Body) + if err != nil { + U.HandleErr(w, r, err) + return + } + r.Body.Close() + err = validateFile(fileType, content) + if err != nil { + U.RespondError(w, err, http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) +} + func SetFileContent(w http.ResponseWriter, r *http.Request) { fileType, filename, err := getArgs(r) if err != nil { @@ -87,19 +119,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { return } - var valErr E.Error - switch fileType { - case FileTypeConfig: - valErr = config.Validate(content) - case FileTypeMiddleware: - errs := E.NewBuilder("middleware errors") - middleware.BuildMiddlewaresFromYAML(filename, content, errs) - valErr = errs.Error() - default: - valErr = provider.Validate(content) - } - - if valErr != nil { + if valErr := validateFile(fileType, content); valErr != nil { U.RespondError(w, valErr, http.StatusBadRequest) return } diff --git a/internal/autocert/config.go b/internal/autocert/config.go index d8682b4f..3a32f73c 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/x509" "os" + "regexp" "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/lego" @@ -13,63 +14,89 @@ import ( "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" - - "github.com/yusing/go-proxy/internal/config/types" ) -type Config types.AutoCertConfig +type ( + AutocertConfig struct { + Email string `json:"email,omitempty"` + Domains []string `json:"domains,omitempty"` + CertPath string `json:"cert_path,omitempty"` + KeyPath string `json:"key_path,omitempty"` + ACMEKeyPath string `json:"acme_key_path,omitempty"` + Provider string `json:"provider,omitempty"` + Options ProviderOpt `json:"options,omitempty"` + } + ProviderOpt map[string]any +) var ( ErrMissingDomain = E.New("missing field 'domains'") ErrMissingEmail = E.New("missing field 'email'") ErrMissingProvider = E.New("missing field 'provider'") + ErrInvalidDomain = E.New("invalid domain") ErrUnknownProvider = E.New("unknown provider") ) -func NewConfig(cfg *types.AutoCertConfig) *Config { +var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) + +// Validate implements the utils.CustomValidator interface. +func (cfg *AutocertConfig) Validate() E.Error { if cfg == nil { - cfg = new(types.AutoCertConfig) + return nil } + + if cfg.Provider == "" { + cfg.Provider = ProviderLocal + return nil + } + + b := E.NewBuilder("autocert errors") + if cfg.Provider != ProviderLocal { + if len(cfg.Domains) == 0 { + b.Add(ErrMissingDomain) + } + if cfg.Email == "" { + b.Add(ErrMissingEmail) + } + for i, d := range cfg.Domains { + if !domainOrWildcardRE.MatchString(d) { + b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + } + } + // check if provider is implemented + providerConstructor, ok := providersGenMap[cfg.Provider] + if !ok { + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap)))) + } else { + _, err := providerConstructor(cfg.Options) + if err != nil { + b.Add(err) + } + } + } + return b.Error() +} + +func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) { + if cfg == nil { + cfg = new(AutocertConfig) + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.CertPath == "" { cfg.CertPath = CertFileDefault } if cfg.KeyPath == "" { cfg.KeyPath = KeyFileDefault } - if cfg.Provider == "" { - cfg.Provider = ProviderLocal - } if cfg.ACMEKeyPath == "" { cfg.ACMEKeyPath = ACMEKeyFileDefault } - return (*Config)(cfg) -} - -func (cfg *Config) GetProvider() (*Provider, E.Error) { - b := E.NewBuilder("autocert errors") - - if cfg.Provider != ProviderLocal { - if len(cfg.Domains) == 0 { - b.Add(ErrMissingDomain) - } - if cfg.Provider == "" { - b.Add(ErrMissingProvider) - } - if cfg.Email == "" { - b.Add(ErrMissingEmail) - } - // check if provider is implemented - _, ok := providersGenMap[cfg.Provider] - if !ok { - b.Add(ErrUnknownProvider. - Subject(cfg.Provider). - Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap)))) - } - } - - if b.HasError() { - return nil, b.Error() - } var privKey *ecdsa.PrivateKey var err error @@ -103,7 +130,7 @@ func (cfg *Config) GetProvider() (*Provider, E.Error) { }, nil } -func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) { +func (cfg *AutocertConfig) loadACMEKey() (*ecdsa.PrivateKey, error) { data, err := os.ReadFile(cfg.ACMEKeyPath) if err != nil { return nil, err @@ -111,7 +138,7 @@ func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) { return x509.ParseECPrivateKey(data) } -func (cfg *Config) saveACMEKey(key *ecdsa.PrivateKey) error { +func (cfg *AutocertConfig) saveACMEKey(key *ecdsa.PrivateKey) error { data, err := x509.MarshalECPrivateKey(key) if err != nil { return err diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index e1cd87df..8a2a1c0d 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -14,7 +14,6 @@ import ( "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" - "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" @@ -24,7 +23,7 @@ import ( type ( Provider struct { - cfg *Config + cfg *AutocertConfig user *User legoCfg *lego.Config client *lego.Client @@ -33,7 +32,7 @@ type ( tlsCert *tls.Certificate certExpiries CertExpiries } - ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error) + ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error) CertExpiries map[string]time.Time ) @@ -313,7 +312,7 @@ func providerGenerator[CT any, PT challenge.Provider]( defaultCfg func() *CT, newProvider func(*CT) (PT, error), ) ProviderGenerator { - return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { + return func(opt ProviderOpt) (challenge.Provider, E.Error) { cfg := defaultCfg() err := U.Deserialize(opt, cfg) if err != nil { diff --git a/internal/autocert/user.go b/internal/autocert/user.go index 9ced682a..00771e13 100644 --- a/internal/autocert/user.go +++ b/internal/autocert/user.go @@ -15,9 +15,11 @@ type User struct { func (u *User) GetEmail() string { return u.Email } + func (u *User) GetRegistration() *registration.Resource { return u.Registration } + func (u *User) GetPrivateKey() crypto.PrivateKey { return u.key } diff --git a/internal/config/config.go b/internal/config/config.go index 94e5cc82..8e96fe69 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -234,7 +234,7 @@ func (cfg *Config) load() E.Error { errs := E.NewBuilder(errMsg) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) - errs.Add(cfg.initNotification(model.Providers.Notification)) + cfg.initNotification(model.Providers.Notification) errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) @@ -249,28 +249,22 @@ func (cfg *Config) load() E.Error { return errs.Error() } -func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) { +func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) { if len(notifCfg) == 0 { return } dispatcher := notif.StartNotifDispatcher(cfg.task) - errs := E.NewBuilder("notification providers load errors") - for i, notifier := range notifCfg { - _, err := dispatcher.RegisterProvider(notifier) - if err == nil { - continue - } - errs.Add(err.Subjectf("[%d]", i)) + for _, notifier := range notifCfg { + dispatcher.RegisterProvider(¬ifier) } - return errs.Error() } -func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { +func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Error) { if cfg.autocertProvider != nil { return } - cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() + cfg.autocertProvider, err = autocertCfg.GetProvider() return } diff --git a/internal/config/types/autocert_config.go b/internal/config/types/autocert_config.go deleted file mode 100644 index f054d01b..00000000 --- a/internal/config/types/autocert_config.go +++ /dev/null @@ -1,14 +0,0 @@ -package types - -type ( - AutoCertConfig struct { - Email string `json:"email,omitempty" validate:"email"` - Domains []string `json:"domains,omitempty"` - CertPath string `json:"cert_path,omitempty" validate:"omitempty,filepath"` - KeyPath string `json:"key_path,omitempty" validate:"omitempty,filepath"` - ACMEKeyPath string `json:"acme_key_path,omitempty" validate:"omitempty,filepath"` - Provider string `json:"provider,omitempty"` - Options AutocertProviderOpt `json:"options,omitempty"` - } - AutocertProviderOpt map[string]any -) diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 8640c78f..e7a7e3de 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -2,8 +2,12 @@ package types import ( "context" + "regexp" + "github.com/go-playground/validator/v10" + "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/utils" E "github.com/yusing/go-proxy/internal/error" @@ -11,23 +15,22 @@ import ( type ( Config struct { - AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"` - Entrypoint Entrypoint `json:"entrypoint"` - Providers Providers `json:"providers"` - MatchDomains []string `json:"match_domains" validate:"dive,fqdn"` - Homepage HomepageConfig `json:"homepage"` - TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` + AutoCert *autocert.AutocertConfig `json:"autocert"` + Entrypoint Entrypoint `json:"entrypoint"` + Providers Providers `json:"providers"` + MatchDomains []string `json:"match_domains" validate:"domain_name"` + Homepage HomepageConfig `json:"homepage"` + TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` } Providers struct { - Files []string `json:"include" validate:"dive,filepath"` - Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"` - Notification []NotificationConfig `json:"notification"` + Files []string `json:"include" validate:"dive,filepath"` + Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"` + Notification []notif.NotificationConfig `json:"notification"` } Entrypoint struct { Middlewares []map[string]any `json:"middlewares"` AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` } - NotificationConfig map[string]any ConfigInstance interface { Value() *Config @@ -52,6 +55,17 @@ func Validate(data []byte) E.Error { return utils.DeserializeYAML(data, &model) } +var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`) + func init() { utils.RegisterDefaultValueFactory(DefaultConfig) + utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool { + domains := fl.Field().Interface().([]string) + for _, domain := range domains { + if !matchDomainsRegex.MatchString(domain) { + return false + } + } + return true + }) } diff --git a/internal/error/utils.go b/internal/error/utils.go index 3bb5e700..5d987fbf 100644 --- a/internal/error/utils.go +++ b/internal/error/utils.go @@ -19,6 +19,13 @@ func Errorf(format string, args ...any) Error { return &baseError{fmt.Errorf(format, args...)} } +func Wrap(err error, message ...string) Error { + if len(message) == 0 || message[0] == "" { + return From(err) + } + return Errorf("%w: %s", err, message[0]) +} + func From(err error) Error { if err == nil { return nil diff --git a/internal/homepage/homepage.go b/internal/homepage/homepage.go index b45ae616..6da2aeda 100644 --- a/internal/homepage/homepage.go +++ b/internal/homepage/homepage.go @@ -1,11 +1,13 @@ package homepage -import "github.com/yusing/go-proxy/internal/utils" +import ( + "github.com/yusing/go-proxy/internal/utils" +) type ( //nolint:recvcheck - Config map[string]Category - Category []*Item + Categories map[string]Category + Category []*Item ItemConfig struct { Show bool `json:"show"` @@ -48,6 +50,10 @@ func NewItem(alias string) *Item { } } +func NewHomePageConfig() Categories { + return Categories(make(map[string]Category)) +} + func (item *Item) IsEmpty() bool { return item == nil || item.IsUnset || item.ItemConfig == nil } @@ -56,15 +62,11 @@ func (item *Item) GetOverride() *Item { return overrideConfigInstance.GetOverride(item) } -func NewHomePageConfig() Config { - return Config(make(map[string]Category)) +func (c *Categories) Clear() { + *c = make(Categories) } -func (c *Config) Clear() { - *c = make(Config) -} - -func (c Config) Add(item *Item) { +func (c Categories) Add(item *Item) { if c[item.Category] == nil { c[item.Category] = make(Category, 0) } diff --git a/internal/homepage/override_config.go b/internal/homepage/override_config.go index 4d10717a..01929e69 100644 --- a/internal/homepage/override_config.go +++ b/internal/homepage/override_config.go @@ -53,10 +53,6 @@ func GetOverrideConfig() *OverrideConfig { return overrideConfigInstance } -func (c *OverrideConfig) UnmarshalJSON(data []byte) error { - return utils.DeserializeJSON(data, c) -} - func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/notif/base.go b/internal/notif/base.go new file mode 100644 index 00000000..272d2d24 --- /dev/null +++ b/internal/notif/base.go @@ -0,0 +1,47 @@ +package notif + +import ( + "net/url" + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +type ProviderBase struct { + Name string `json:"name" validate:"required"` + URL string `json:"url" validate:"url"` + Token string `json:"token"` +} + +var ( + ErrMissingToken = E.New("token is required") + ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'") +) + +// Validate implements the utils.CustomValidator interface. +func (base *ProviderBase) Validate() E.Error { + if base.Token == "" { + return ErrMissingToken + } + if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") { + return ErrURLMissingScheme + } + u, err := url.Parse(base.URL) + if err != nil { + return E.Wrap(err) + } + base.URL = u.String() + return nil +} + +func (base *ProviderBase) GetName() string { + return base.Name +} + +func (base *ProviderBase) GetURL() string { + return base.URL +} + +func (base *ProviderBase) GetToken() string { + return base.Token +} diff --git a/internal/notif/config.go b/internal/notif/config.go new file mode 100644 index 00000000..64d99570 --- /dev/null +++ b/internal/notif/config.go @@ -0,0 +1,54 @@ +package notif + +import ( + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" +) + +type NotificationConfig struct { + ProviderName string `json:"provider"` + Provider Provider `json:"-"` +} + +var ( + ErrMissingNotifProvider = E.New("missing notification provider") + ErrInvalidNotifProviderType = E.New("invalid notification provider type") + ErrUnknownNotifProvider = E.New("unknown notification provider") +) + +// UnmarshalMap implements MapUnmarshaler. +func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) { + // extract provider name + providerName := m["provider"] + switch providerName := providerName.(type) { + case string: + cfg.ProviderName = providerName + default: + return ErrInvalidNotifProviderType + } + delete(m, "provider") + + if cfg.ProviderName == "" { + return ErrMissingNotifProvider + } + + // validate provider name and initialize provider + switch cfg.ProviderName { + case ProviderWebhook: + cfg.Provider = &Webhook{} + case ProviderGotify: + cfg.Provider = &GotifyClient{} + default: + return ErrUnknownNotifProvider. + Subject(cfg.ProviderName). + Withf("expect %s or %s", ProviderWebhook, ProviderGotify) + } + + // unmarshal provider config + if err := utils.Deserialize(m, cfg.Provider); err != nil { + return err + } + + // validate provider + return cfg.Provider.Validate() +} diff --git a/internal/notif/config_test.go b/internal/notif/config_test.go new file mode 100644 index 00000000..c9f6bffb --- /dev/null +++ b/internal/notif/config_test.go @@ -0,0 +1,163 @@ +package notif + +import ( + "net/http" + "testing" + + "github.com/yusing/go-proxy/internal/utils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestNotificationConfig(t *testing.T) { + tests := []struct { + name string + cfg map[string]any + expected Provider + wantErr bool + }{ + { + name: "valid_webhook", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "template": "discord", + "url": "https://example.com", + }, + expected: &Webhook{ + ProviderBase: ProviderBase{ + Name: "test", + URL: "https://example.com", + }, + Template: "discord", + Method: http.MethodPost, + MIMEType: "application/json", + ColorMode: "dec", + Payload: discordPayload, + }, + wantErr: false, + }, + { + name: "valid_gotify", + cfg: map[string]any{ + "name": "test", + "provider": "gotify", + "url": "https://example.com", + "token": "token", + }, + expected: &GotifyClient{ + ProviderBase: ProviderBase{ + Name: "test", + URL: "https://example.com", + Token: "token", + }, + }, + wantErr: false, + }, + { + name: "invalid_provider", + cfg: map[string]any{ + "name": "test", + "provider": "invalid", + "url": "https://example.com", + }, + wantErr: true, + }, + { + name: "missing_url", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + }, + wantErr: true, + }, + { + name: "missing_provider", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + }, + wantErr: true, + }, + { + name: "gotify_missing_token", + cfg: map[string]any{ + "name": "test", + "provider": "gotify", + "url": "https://example.com", + }, + wantErr: true, + }, + { + name: "webhook_missing_payload", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "url": "https://example.com", + }, + wantErr: true, + }, + { + name: "webhook_missing_url", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + }, + wantErr: true, + }, + { + name: "webhook_invalid_template", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "url": "https://example.com", + "template": "invalid", + }, + wantErr: true, + }, + { + name: "webhook_invalid_json_payload", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "url": "https://example.com", + "mime_type": "application/json", + "payload": "invalid", + }, + wantErr: true, + }, + { + name: "webhook_empty_text_payload", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "url": "https://example.com", + "mime_type": "text/plain", + }, + wantErr: true, + }, + { + name: "webhook_invalid_method", + cfg: map[string]any{ + "name": "test", + "provider": "webhook", + "url": "https://example.com", + "method": "invalid", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg NotificationConfig + provider := tt.cfg["provider"] + err := utils.Deserialize(tt.cfg, &cfg) + if tt.wantErr { + ExpectHasError(t, err) + } else { + ExpectNoError(t, err) + ExpectEqual(t, provider.(string), cfg.ProviderName) + ExpectDeepEqual(t, cfg.Provider, tt.expected) + } + }) + } +} diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index d38d261b..af466e17 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -2,13 +2,10 @@ package notif import ( "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" - "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -27,12 +24,6 @@ type ( var dispatcher *Dispatcher -var ( - ErrMissingNotifProvider = E.New("missing notification provider") - ErrInvalidNotifProviderType = E.New("invalid notification provider type") - ErrUnknownNotifProvider = E.New("unknown notification provider") -) - const dispatchErr = "notification dispatch error" func StartNotifDispatcher(parent task.Parent) *Dispatcher { @@ -57,29 +48,8 @@ func Notify(msg *LogMessage) { } } -func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) { - providerName, ok := cfg["provider"] - if !ok { - return nil, ErrMissingNotifProvider - } - switch providerName := providerName.(type) { - case string: - delete(cfg, "provider") - createFunc, ok := Providers[providerName] - if !ok { - return nil, ErrUnknownNotifProvider. - Subject(providerName). - Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers))) - } - - provider, err := createFunc(cfg) - if err == nil { - disp.providers.Add(provider) - } - return provider, err - default: - return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName) - } +func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) { + disp.providers.Add(cfg.Provider) } func (disp *Dispatcher) start() { @@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) { errs := E.NewBuilder(dispatchErr) disp.providers.RangeAllParallel(func(p Provider) { if err := notifyProvider(task.Context(), p, msg); err != nil { - errs.Add(E.PrependSubject(p.Name(), err)) + errs.Add(E.PrependSubject(p.GetName(), err)) } }) if errs.HasError() { diff --git a/internal/notif/gotify.go b/internal/notif/gotify.go index fc846896..d0a7b1ba 100644 --- a/internal/notif/gotify.go +++ b/internal/notif/gotify.go @@ -13,37 +13,24 @@ import ( type ( GotifyClient struct { - N string `json:"name" validate:"required"` - U string `json:"url" validate:"url"` - Tok string `json:"token" validate:"required"` + ProviderBase } GotifyMessage model.MessageExternal ) const gotifyMsgEndpoint = "/message" -// Name implements Provider. -func (client *GotifyClient) Name() string { - return client.N +func (client *GotifyClient) GetURL() string { + return client.URL + gotifyMsgEndpoint } -// Method implements Provider. -func (client *GotifyClient) Method() string { +// GetMethod implements Provider. +func (client *GotifyClient) GetMethod() string { return http.MethodPost } -// URL implements Provider. -func (client *GotifyClient) URL() string { - return client.U + gotifyMsgEndpoint -} - -// Token implements Provider. -func (client *GotifyClient) Token() string { - return client.Tok -} - -// MIMEType implements Provider. -func (client *GotifyClient) MIMEType() string { +// GetMIMEType implements Provider. +func (client *GotifyClient) GetMIMEType() string { return "application/json" } diff --git a/internal/notif/gotify_test.go b/internal/notif/gotify_test.go deleted file mode 100644 index fda36edc..00000000 --- a/internal/notif/gotify_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package notif - -import ( - "testing" - - "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" -) - -func TestGotifyValidation(t *testing.T) { - t.Parallel() - - newGotify := Providers[ProviderGotify] - - t.Run("valid", func(t *testing.T) { - t.Parallel() - _, err := newGotify(map[string]any{ - "name": "test", - "url": "https://example.com", - "token": "token", - }) - ExpectNoError(t, err) - }) - - t.Run("missing url", func(t *testing.T) { - t.Parallel() - _, err := newGotify(map[string]any{ - "name": "test", - "token": "token", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - - t.Run("missing token", func(t *testing.T) { - t.Parallel() - _, err := newGotify(map[string]any{ - "name": "test", - "url": "https://example.com", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - - t.Run("invalid url", func(t *testing.T) { - t.Parallel() - _, err := newGotify(map[string]any{ - "name": "test", - "url": "example.com", - "token": "token", - }) - ExpectError(t, utils.ErrValidationError, err) - }) -} diff --git a/internal/notif/providers.go b/internal/notif/providers.go index bb6aa5b3..3082b2b5 100644 --- a/internal/notif/providers.go +++ b/internal/notif/providers.go @@ -2,22 +2,24 @@ package notif import ( "context" - "fmt" "io" "net/http" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils" ) type ( Provider interface { - Name() string - URL() string - Method() string - Token() string - MIMEType() string + utils.CustomValidator + + GetName() string + GetURL() string + GetToken() string + GetMethod() string + GetMIMEType() string + MakeBody(logMsg *LogMessage) (io.Reader, error) makeRespError(resp *http.Response) error @@ -31,47 +33,29 @@ const ( ProviderWebhook = "webhook" ) -var Providers = map[string]ProviderCreateFunc{ - ProviderGotify: newNotifProvider[*GotifyClient], - ProviderWebhook: newNotifProvider[*Webhook], -} - -func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) { - var client T - err := U.Deserialize(cfg, &client) - if err != nil { - return nil, err.Subject(client.Name()) - } - return client, nil -} - -func formatError(p Provider, err error) error { - return fmt.Errorf("%s error: %w", p.Name(), err) -} - func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error { body, err := provider.MakeBody(msg) if err != nil { - return formatError(provider, err) + return E.PrependSubject(provider.GetName(), err) } req, err := http.NewRequestWithContext( ctx, http.MethodPost, - provider.URL(), + provider.GetURL(), body, ) if err != nil { - return formatError(provider, err) + return E.PrependSubject(provider.GetName(), err) } - req.Header.Set("Content-Type", provider.MIMEType()) - if provider.Token() != "" { - req.Header.Set("Authorization", "Bearer "+provider.Token()) + req.Header.Set("Content-Type", provider.GetMIMEType()) + if provider.GetToken() != "" { + req.Header.Set("Authorization", "Bearer "+provider.GetToken()) } resp, err := http.DefaultClient.Do(req) if err != nil { - return formatError(provider, err) + return E.PrependSubject(provider.GetName(), err) } defer resp.Body.Close() diff --git a/internal/notif/webhook.go b/internal/notif/webhook.go index 46d5f239..f3d9d15b 100644 --- a/internal/notif/webhook.go +++ b/internal/notif/webhook.go @@ -8,19 +8,16 @@ import ( "net/http" "strings" - "github.com/go-playground/validator/v10" - "github.com/yusing/go-proxy/internal/utils" + E "github.com/yusing/go-proxy/internal/error" ) type Webhook struct { - N string `json:"name" validate:"required"` - U string `json:"url" validate:"url"` - Template string `json:"template" validate:"omitempty,oneof=discord"` - Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"` - Tok string `json:"token"` - Meth string `json:"method" validate:"oneof=GET POST PUT"` - MIMETyp string `json:"mime_type"` - ColorM string `json:"color_mode" validate:"oneof=hex dec"` + ProviderBase + Template string `json:"template"` + Payload string `json:"payload"` + Method string `json:"method"` + MIMEType string `json:"mime_type"` + ColorMode string `json:"color_mode"` } //go:embed templates/discord.json @@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{ "discord": discordPayload, } -func DefaultValue() *Webhook { - return &Webhook{ - Meth: "POST", - ColorM: "hex", - MIMETyp: "application/json", +func (webhook *Webhook) Validate() E.Error { + if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) { + return err } -} -func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool { - template := fl.Parent().FieldByName("Template").String() - if template != "" { - return true - } - payload := fl.Field().String() - return json.Valid([]byte(payload)) -} - -func init() { - utils.RegisterDefaultValueFactory(DefaultValue) - utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed) -} - -// Name implements Provider. -func (webhook *Webhook) Name() string { - return webhook.N -} - -// Method implements Provider. -func (webhook *Webhook) Method() string { - return webhook.Meth -} - -// URL implements Provider. -func (webhook *Webhook) URL() string { - return webhook.U -} - -// Token implements Provider. -func (webhook *Webhook) Token() string { - return webhook.Tok -} - -// MIMEType implements Provider. -func (webhook *Webhook) MIMEType() string { - return webhook.MIMETyp -} - -func (webhook *Webhook) ColorMode() string { - switch webhook.Template { - case "discord": - return "dec" + switch webhook.MIMEType { + case "": + webhook.MIMEType = "application/json" + case "application/json", "application/x-www-form-urlencoded", "text/plain": default: - return webhook.ColorM + return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'") } + + switch webhook.Template { + case "": + if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) { + return E.New("invalid payload, expect valid JSON") + } + if webhook.Payload == "" { + return E.New("invalid payload, expect non-empty") + } + case "discord": + webhook.ColorMode = "dec" + webhook.Method = http.MethodPost + webhook.MIMEType = "application/json" + if webhook.Payload == "" { + webhook.Payload = discordPayload + } + default: + return E.New("invalid template, expect empty or 'discord'") + } + + switch webhook.Method { + case "": + webhook.Method = http.MethodPost + case http.MethodGet, http.MethodPost, http.MethodPut: + default: + return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'") + } + + switch webhook.ColorMode { + case "": + webhook.ColorMode = "hex" + case "hex", "dec": + default: + return E.New("invalid color_mode, expect empty, 'hex' or 'dec'") + } + + return nil +} + +// GetMethod implements Provider. +func (webhook *Webhook) GetMethod() string { + return webhook.Method +} + +// GetMIMEType implements Provider. +func (webhook *Webhook) GetMIMEType() string { + return webhook.MIMEType } // makeRespError implements Provider. @@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) { return nil, err } var color string - if webhook.ColorMode() == "hex" { + if webhook.ColorMode == "hex" { color = logMsg.Color.HexString() } else { color = logMsg.Color.DecString() diff --git a/internal/notif/webhook_test.go b/internal/notif/webhook_test.go deleted file mode 100644 index 9fb31ab8..00000000 --- a/internal/notif/webhook_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package notif - -import ( - "encoding/json" - "testing" - - "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" -) - -func TestWebhookValidation(t *testing.T) { - t.Parallel() - - newWebhook := Providers[ProviderWebhook] - - t.Run("valid", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - "payload": "{}", - }) - ExpectNoError(t, err) - }) - t.Run("valid template", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - "template": "discord", - }) - ExpectNoError(t, err) - }) - - t.Run("missing url", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "payload": "{}", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - - t.Run("missing payload", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - t.Run("invalid url", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "example.com", - "payload": "{}", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - "payload": "abcd", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - t.Run("invalid method", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - "payload": "{}", - "method": "abcd", - }) - ExpectError(t, utils.ErrValidationError, err) - }) - t.Run("invalid template", func(t *testing.T) { - t.Parallel() - _, err := newWebhook(map[string]any{ - "name": "test", - "url": "https://example.com", - "template": "abcd", - }) - ExpectError(t, utils.ErrValidationError, err) - }) -} - -func TestWebhookBody(t *testing.T) { - t.Parallel() - - var webhook Webhook - webhook.Payload = discordPayload - bodyReader, err := webhook.MakeBody(&LogMessage{ - Title: "abc", - Extras: map[string]any{ - "foo": "bar", - }, - }) - ExpectNoError(t, err) - - var body struct { - Embeds []struct { - Title string `json:"title"` - Fields []struct { - Name string `json:"name"` - Value string `json:"value"` - } `json:"fields"` - } `json:"embeds"` - } - - err = json.NewDecoder(bodyReader).Decode(&body) - ExpectNoError(t, err) - - ExpectEqual(t, body.Embeds[0].Title, "abc") - fields := body.Embeds[0].Fields - ExpectEqual(t, fields[0].Name, "foo") - ExpectEqual(t, fields[0].Value, "bar") -} diff --git a/internal/route/routes/routequery/query.go b/internal/route/routes/routequery/query.go index e420e18a..d92d7d36 100644 --- a/internal/route/routes/routequery/query.go +++ b/internal/route/routes/routequery/query.go @@ -57,7 +57,7 @@ func HomepageCategories() []string { return categories } -func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Config { +func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Categories { hpCfg := homepage.NewHomePageConfig() routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) { diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 13da6216..f552c298 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -1,13 +1,10 @@ package utils -// FIXME: some times [%d] is not in correct order - import ( "encoding/json" "errors" "os" "reflect" - "runtime/debug" "strconv" "strings" "time" @@ -21,6 +18,10 @@ import ( type SerializedObject = map[string]any +type MapUnmarshaller interface { + UnmarshalMap(m map[string]any) E.Error +} + var ( ErrInvalidType = E.New("invalid type") ErrNilValue = E.New("nil") @@ -29,6 +30,8 @@ var ( ErrUnknownField = E.New("unknown field") ) +var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]() + var defaultValues = functional.NewMapOf[reflect.Type, func() any]() func RegisterDefaultValueFactory[T any](factory func() *T) { @@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) { if t.Kind() != reflect.Struct { return nil, nil } - var fields []reflect.StructField - for i := range t.NumField() { + n := t.NumField() + fields := make([]reflect.StructField, 0, n) + for i := range n { field := t.Field(i) if !field.IsExported() { continue @@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) { return fields, anonymous } +func ValidateWithFieldTags(s any) E.Error { + errs := E.NewBuilder("validate error") + err := validate.Struct(s) + var valErrs validator.ValidationErrors + if errors.As(err, &valErrs) { + for _, e := range valErrs { + detail := e.ActualTag() + if e.Param() != "" { + detail += ":" + e.Param() + } + errs.Add(ErrValidationError. + Subject(e.Namespace()). + Withf("require %q", detail)) + } + } + return errs.Error() +} + // Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value. // Deserialize ignores case differences between the field names in the SerializedObject and the target. // // The target value must be a struct or a map[string]any. -// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed. -// If the target value is a map[string]any, the SerializedObject will be deserialized into the map. +// If the target value is a struct , and implements the MapUnmarshaller interface, +// the UnmarshalMap method will be called. +// +// If the target value is a struct, but does not implements the MapUnmarshaller interface, +// the SerializedObject will be deserialized into the struct fields and validate if needed. +// +// If the target value is a map[string]any the SerializedObject will be deserialized into the map. // // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. func Deserialize(src SerializedObject, dst any) E.Error { - if src == nil { - return E.Errorf("deserialize: src is %w", ErrNilValue) - } - if dst == nil { - return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack()) - } - dstV := reflect.ValueOf(dst) dstT := dstV.Type() + if src == nil { + if dstV.CanSet() { + dstV.Set(reflect.Zero(dstT)) + return nil + } + return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue) + } + + if dstT.Implements(mapUnmarshalerType) { + for dstV.IsNil() { + switch dstT.Kind() { + case reflect.Struct: + dstV.Set(New(dstT)) + case reflect.Map: + dstV.Set(reflect.MakeMap(dstT)) + case reflect.Slice: + dstV.Set(reflect.MakeSlice(dstT, 0, 0)) + case reflect.Ptr: + dstV.Set(reflect.New(dstT.Elem())) + default: + return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String()) + } + dstV = dstV.Elem() + } + return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src) + } + for dstT.Kind() == reflect.Ptr { if dstV.IsNil() { if dstV.CanSet() { dstV.Set(New(dstT.Elem())) } else { - return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack()) + return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue) } } dstV = dstV.Elem() @@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error { switch dstV.Kind() { case reflect.Struct: - needValidate := false + hasValidateTag := false mapping := make(map[string]reflect.Value) - fieldName := make(map[string]string) fields, anonymous := extractFields(dstT) for _, anon := range anonymous { if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() { @@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error { } key = strutils.ToLowerNoSnake(key) mapping[key] = dstV.FieldByName(field.Name) - fieldName[field.Name] = key - if !needValidate { - _, needValidate = field.Tag.Lookup("validate") + if !hasValidateTag { + _, hasValidateTag = field.Tag.Lookup("validate") } aliases, ok := field.Tag.Lookup("aliases") if ok { for _, alias := range strutils.CommaSeperatedList(aliases) { mapping[alias] = dstV.FieldByName(field.Name) - fieldName[field.Name] = alias } } } @@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error { errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping)))) } } - if needValidate { - err := validate.Struct(dstV.Interface()) - var valErrs validator.ValidationErrors - if errors.As(err, &valErrs) { - for _, e := range valErrs { - detail := e.ActualTag() - if e.Param() != "" { - detail += ":" + e.Param() - } - errs.Add(ErrValidationError. - Subject(e.StructNamespace()). - Withf("require %q", detail)) - } - } + if hasValidateTag { + errs.Add(ValidateWithFieldTags(dstV.Interface())) + } else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok { + errs.Add(validator.Validate()) } return errs.Error() case reflect.Map: @@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error { errs.Add(err.Subject(k)) } } + if validator, ok := dstV.Addr().Interface().(CustomValidator); ok { + errs.Add(validator.Validate()) + } return errs.Error() default: return ErrUnsupportedConversion.Subject("mapping to " + dstT.String()) @@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E. return functional.NewMapFrom(m2), nil } -func DeserializeJSON[T any](data []byte, target T) E.Error { - m := make(map[string]any) - if err := json.Unmarshal(data, &m); err != nil { - return E.From(err) - } - return Deserialize(m, target) -} - func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error { data, err := os.ReadFile(path) if err != nil { diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index c19e5b87..b95ba7d1 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -35,6 +35,14 @@ func ExpectNoError(t *testing.T, err error) { } } +func ExpectHasError(t *testing.T, err error) { + t.Helper() + if errors.Is(err, nil) { + t.Error("expected err not nil") + t.FailNow() + } +} + func ExpectError(t *testing.T, expected error, err error) { t.Helper() if !errors.Is(err, expected) { diff --git a/internal/utils/validation.go b/internal/utils/validation.go index 226657a3..32db1242 100644 --- a/internal/utils/validation.go +++ b/internal/utils/validation.go @@ -9,6 +9,10 @@ var validate = validator.New() var ErrValidationError = E.New("validation error") +type CustomValidator interface { + Validate() E.Error +} + func Validator() *validator.Validate { return validate }