Cleaned up some validation code, stricter validation

This commit is contained in:
yusing
2025-01-26 14:43:48 +08:00
parent 254224c0e8
commit 1586610a44
23 changed files with 590 additions and 468 deletions

47
internal/notif/base.go Normal file
View File

@@ -0,0 +1,47 @@
package notif
import (
"net/url"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type ProviderBase struct {
Name string `json:"name" validate:"required"`
URL string `json:"url" validate:"url"`
Token string `json:"token"`
}
var (
ErrMissingToken = E.New("token is required")
ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'")
)
// Validate implements the utils.CustomValidator interface.
func (base *ProviderBase) Validate() E.Error {
if base.Token == "" {
return ErrMissingToken
}
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
return ErrURLMissingScheme
}
u, err := url.Parse(base.URL)
if err != nil {
return E.Wrap(err)
}
base.URL = u.String()
return nil
}
func (base *ProviderBase) GetName() string {
return base.Name
}
func (base *ProviderBase) GetURL() string {
return base.URL
}
func (base *ProviderBase) GetToken() string {
return base.Token
}

54
internal/notif/config.go Normal file
View File

@@ -0,0 +1,54 @@
package notif
import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
)
type NotificationConfig struct {
ProviderName string `json:"provider"`
Provider Provider `json:"-"`
}
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
// UnmarshalMap implements MapUnmarshaler.
func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) {
// extract provider name
providerName := m["provider"]
switch providerName := providerName.(type) {
case string:
cfg.ProviderName = providerName
default:
return ErrInvalidNotifProviderType
}
delete(m, "provider")
if cfg.ProviderName == "" {
return ErrMissingNotifProvider
}
// validate provider name and initialize provider
switch cfg.ProviderName {
case ProviderWebhook:
cfg.Provider = &Webhook{}
case ProviderGotify:
cfg.Provider = &GotifyClient{}
default:
return ErrUnknownNotifProvider.
Subject(cfg.ProviderName).
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
}
// unmarshal provider config
if err := utils.Deserialize(m, cfg.Provider); err != nil {
return err
}
// validate provider
return cfg.Provider.Validate()
}

View File

@@ -0,0 +1,163 @@
package notif
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNotificationConfig(t *testing.T) {
tests := []struct {
name string
cfg map[string]any
expected Provider
wantErr bool
}{
{
name: "valid_webhook",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"template": "discord",
"url": "https://example.com",
},
expected: &Webhook{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
},
Template: "discord",
Method: http.MethodPost,
MIMEType: "application/json",
ColorMode: "dec",
Payload: discordPayload,
},
wantErr: false,
},
{
name: "valid_gotify",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
"token": "token",
},
expected: &GotifyClient{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
Token: "token",
},
},
wantErr: false,
},
{
name: "invalid_provider",
cfg: map[string]any{
"name": "test",
"provider": "invalid",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "missing_provider",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "gotify_missing_token",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "webhook_invalid_template",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"template": "invalid",
},
wantErr: true,
},
{
name: "webhook_invalid_json_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "application/json",
"payload": "invalid",
},
wantErr: true,
},
{
name: "webhook_empty_text_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "text/plain",
},
wantErr: true,
},
{
name: "webhook_invalid_method",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"method": "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var cfg NotificationConfig
provider := tt.cfg["provider"]
err := utils.Deserialize(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
} else {
ExpectNoError(t, err)
ExpectEqual(t, provider.(string), cfg.ProviderName)
ExpectDeepEqual(t, cfg.Provider, tt.expected)
}
})
}
}

View File

@@ -2,13 +2,10 @@ package notif
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
@@ -27,12 +24,6 @@ type (
var dispatcher *Dispatcher
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
@@ -57,29 +48,8 @@ func Notify(msg *LogMessage) {
}
}
func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) {
providerName, ok := cfg["provider"]
if !ok {
return nil, ErrMissingNotifProvider
}
switch providerName := providerName.(type) {
case string:
delete(cfg, "provider")
createFunc, ok := Providers[providerName]
if !ok {
return nil, ErrUnknownNotifProvider.
Subject(providerName).
Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers)))
}
provider, err := createFunc(cfg)
if err == nil {
disp.providers.Add(provider)
}
return provider, err
default:
return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName)
}
func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) {
disp.providers.Add(cfg.Provider)
}
func (disp *Dispatcher) start() {
@@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
errs := E.NewBuilder(dispatchErr)
disp.providers.RangeAllParallel(func(p Provider) {
if err := notifyProvider(task.Context(), p, msg); err != nil {
errs.Add(E.PrependSubject(p.Name(), err))
errs.Add(E.PrependSubject(p.GetName(), err))
}
})
if errs.HasError() {

View File

@@ -13,37 +13,24 @@ import (
type (
GotifyClient struct {
N string `json:"name" validate:"required"`
U string `json:"url" validate:"url"`
Tok string `json:"token" validate:"required"`
ProviderBase
}
GotifyMessage model.MessageExternal
)
const gotifyMsgEndpoint = "/message"
// Name implements Provider.
func (client *GotifyClient) Name() string {
return client.N
func (client *GotifyClient) GetURL() string {
return client.URL + gotifyMsgEndpoint
}
// Method implements Provider.
func (client *GotifyClient) Method() string {
// GetMethod implements Provider.
func (client *GotifyClient) GetMethod() string {
return http.MethodPost
}
// URL implements Provider.
func (client *GotifyClient) URL() string {
return client.U + gotifyMsgEndpoint
}
// Token implements Provider.
func (client *GotifyClient) Token() string {
return client.Tok
}
// MIMEType implements Provider.
func (client *GotifyClient) MIMEType() string {
// GetMIMEType implements Provider.
func (client *GotifyClient) GetMIMEType() string {
return "application/json"
}

View File

@@ -1,52 +0,0 @@
package notif
import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestGotifyValidation(t *testing.T) {
t.Parallel()
newGotify := Providers[ProviderGotify]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
"token": "token",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing token", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "example.com",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
}

View File

@@ -2,22 +2,24 @@ package notif
import (
"context"
"fmt"
"io"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Provider interface {
Name() string
URL() string
Method() string
Token() string
MIMEType() string
utils.CustomValidator
GetName() string
GetURL() string
GetToken() string
GetMethod() string
GetMIMEType() string
MakeBody(logMsg *LogMessage) (io.Reader, error)
makeRespError(resp *http.Response) error
@@ -31,47 +33,29 @@ const (
ProviderWebhook = "webhook"
)
var Providers = map[string]ProviderCreateFunc{
ProviderGotify: newNotifProvider[*GotifyClient],
ProviderWebhook: newNotifProvider[*Webhook],
}
func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) {
var client T
err := U.Deserialize(cfg, &client)
if err != nil {
return nil, err.Subject(client.Name())
}
return client, nil
}
func formatError(p Provider, err error) error {
return fmt.Errorf("%s error: %w", p.Name(), err)
}
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
body, err := provider.MakeBody(msg)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
provider.URL(),
provider.GetURL(),
body,
)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
req.Header.Set("Content-Type", provider.MIMEType())
if provider.Token() != "" {
req.Header.Set("Authorization", "Bearer "+provider.Token())
req.Header.Set("Content-Type", provider.GetMIMEType())
if provider.GetToken() != "" {
req.Header.Set("Authorization", "Bearer "+provider.GetToken())
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
defer resp.Body.Close()

View File

@@ -8,19 +8,16 @@ import (
"net/http"
"strings"
"github.com/go-playground/validator/v10"
"github.com/yusing/go-proxy/internal/utils"
E "github.com/yusing/go-proxy/internal/error"
)
type Webhook struct {
N string `json:"name" validate:"required"`
U string `json:"url" validate:"url"`
Template string `json:"template" validate:"omitempty,oneof=discord"`
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"`
Tok string `json:"token"`
Meth string `json:"method" validate:"oneof=GET POST PUT"`
MIMETyp string `json:"mime_type"`
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
ProviderBase
Template string `json:"template"`
Payload string `json:"payload"`
Method string `json:"method"`
MIMEType string `json:"mime_type"`
ColorMode string `json:"color_mode"`
}
//go:embed templates/discord.json
@@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{
"discord": discordPayload,
}
func DefaultValue() *Webhook {
return &Webhook{
Meth: "POST",
ColorM: "hex",
MIMETyp: "application/json",
func (webhook *Webhook) Validate() E.Error {
if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) {
return err
}
}
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
template := fl.Parent().FieldByName("Template").String()
if template != "" {
return true
}
payload := fl.Field().String()
return json.Valid([]byte(payload))
}
func init() {
utils.RegisterDefaultValueFactory(DefaultValue)
utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
}
// Name implements Provider.
func (webhook *Webhook) Name() string {
return webhook.N
}
// Method implements Provider.
func (webhook *Webhook) Method() string {
return webhook.Meth
}
// URL implements Provider.
func (webhook *Webhook) URL() string {
return webhook.U
}
// Token implements Provider.
func (webhook *Webhook) Token() string {
return webhook.Tok
}
// MIMEType implements Provider.
func (webhook *Webhook) MIMEType() string {
return webhook.MIMETyp
}
func (webhook *Webhook) ColorMode() string {
switch webhook.Template {
case "discord":
return "dec"
switch webhook.MIMEType {
case "":
webhook.MIMEType = "application/json"
case "application/json", "application/x-www-form-urlencoded", "text/plain":
default:
return webhook.ColorM
return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'")
}
switch webhook.Template {
case "":
if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) {
return E.New("invalid payload, expect valid JSON")
}
if webhook.Payload == "" {
return E.New("invalid payload, expect non-empty")
}
case "discord":
webhook.ColorMode = "dec"
webhook.Method = http.MethodPost
webhook.MIMEType = "application/json"
if webhook.Payload == "" {
webhook.Payload = discordPayload
}
default:
return E.New("invalid template, expect empty or 'discord'")
}
switch webhook.Method {
case "":
webhook.Method = http.MethodPost
case http.MethodGet, http.MethodPost, http.MethodPut:
default:
return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'")
}
switch webhook.ColorMode {
case "":
webhook.ColorMode = "hex"
case "hex", "dec":
default:
return E.New("invalid color_mode, expect empty, 'hex' or 'dec'")
}
return nil
}
// GetMethod implements Provider.
func (webhook *Webhook) GetMethod() string {
return webhook.Method
}
// GetMIMEType implements Provider.
func (webhook *Webhook) GetMIMEType() string {
return webhook.MIMEType
}
// makeRespError implements Provider.
@@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
return nil, err
}
var color string
if webhook.ColorMode() == "hex" {
if webhook.ColorMode == "hex" {
color = logMsg.Color.HexString()
} else {
color = logMsg.Color.DecString()

View File

@@ -1,121 +0,0 @@
package notif
import (
"encoding/json"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestWebhookValidation(t *testing.T) {
t.Parallel()
newWebhook := Providers[ProviderWebhook]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
})
ExpectNoError(t, err)
})
t.Run("valid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "discord",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "example.com",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid method", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
"method": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestWebhookBody(t *testing.T) {
t.Parallel()
var webhook Webhook
webhook.Payload = discordPayload
bodyReader, err := webhook.MakeBody(&LogMessage{
Title: "abc",
Extras: map[string]any{
"foo": "bar",
},
})
ExpectNoError(t, err)
var body struct {
Embeds []struct {
Title string `json:"title"`
Fields []struct {
Name string `json:"name"`
Value string `json:"value"`
} `json:"fields"`
} `json:"embeds"`
}
err = json.NewDecoder(bodyReader).Decode(&body)
ExpectNoError(t, err)
ExpectEqual(t, body.Embeds[0].Title, "abc")
fields := body.Embeds[0].Fields
ExpectEqual(t, fields[0].Name, "foo")
ExpectEqual(t, fields[0].Value, "bar")
}