diff --git a/internal/config/config.go b/internal/config/config.go index 770045ad..21734307 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -187,14 +187,17 @@ func (cfg *Config) load() E.Error { return errs.Error() } -func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) { - if len(notifCfgMap) == 0 { +func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) { + if len(notifCfg) == 0 { return } errs := E.NewBuilder("notification providers load errors") - for name, notifCfg := range notifCfgMap { - _, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg) - errs.Add(err) + for i, notifier := range notifCfg { + _, err := notif.RegisterProvider(cfg.task.Subtask("notifier"), notifier) + if err == nil { + continue + } + errs.Add(err.Subjectf("[%d]", i)) } return errs.Error() } diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 487cdf22..3a40687a 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -11,10 +11,11 @@ type ( RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"` } Providers struct { - Files []string `json:"include" yaml:"include"` - Docker map[string]string `json:"docker" yaml:"docker"` - Notification NotificationConfigMap `json:"notification" yaml:"notification"` + Files []string `json:"include" yaml:"include"` + Docker map[string]string `json:"docker" yaml:"docker"` + Notification []NotificationConfig `json:"notification" yaml:"notification"` } + NotificationConfig map[string]any ) func DefaultConfig() *Config { diff --git a/internal/config/types/notif_config.go b/internal/config/types/notif_config.go deleted file mode 100644 index e9214c79..00000000 --- a/internal/config/types/notif_config.go +++ /dev/null @@ -1,5 +0,0 @@ -package types - -import "github.com/yusing/go-proxy/internal/notif" - -type NotificationConfigMap map[string]notif.ProviderConfig diff --git a/internal/notif/color.go b/internal/notif/color.go new file mode 100644 index 00000000..d346759a --- /dev/null +++ b/internal/notif/color.go @@ -0,0 +1,23 @@ +package notif + +import "fmt" + +type Color uint + +const ( + Red Color = 0xff0000 + Green Color = 0x00ff00 + Blue Color = 0x0000ff +) + +func (c Color) HexString() string { + return fmt.Sprintf("#%x", c) +} + +func (c Color) DecString() string { + return fmt.Sprintf("%d", c) +} + +func (c Color) String() string { + return c.HexString() +} diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index c6d390d0..148b581f 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -2,6 +2,7 @@ package notif import ( "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" @@ -17,14 +18,20 @@ type ( providers F.Set[Provider] } LogMessage struct { - Level zerolog.Level - Title, Message string + Level zerolog.Level + Title string + Extras map[string]any + Color Color } ) var dispatcher *Dispatcher -var ErrUnknownNotifProvider = E.New("unknown notification provider") +var ( + ErrMissingNotifProvider = E.New("missing notification provider") + ErrInvalidNotifProviderType = E.New("invalid notification provider type") + ErrUnknownNotifProvider = E.New("unknown notification provider") +) const dispatchErr = "notification dispatch error" @@ -45,23 +52,32 @@ func GetDispatcher() *Dispatcher { return dispatcher } -func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) { - name := configSubTask.Name() - createFunc, ok := Providers[name] +func RegisterProvider(configSubTask task.Task, cfg types.NotificationConfig) (Provider, E.Error) { + providerName, ok := cfg["provider"] if !ok { - return nil, ErrUnknownNotifProvider. - Subject(name). - Withf(strutils.DoYouMean(utils.NearestField(name, Providers))) + return nil, ErrMissingNotifProvider } + switch providerName := providerName.(type) { + case string: + delete(cfg, "provider") + createFunc, ok := Providers[providerName] + if !ok { + return nil, ErrUnknownNotifProvider. + Subject(providerName). + Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers))) + } - provider, err := createFunc(cfg) - if err == nil { - dispatcher.providers.Add(provider) - configSubTask.OnCancel("remove provider", func() { - dispatcher.providers.Remove(provider) - }) + provider, err := createFunc(cfg) + if err == nil { + dispatcher.providers.Add(provider) + configSubTask.OnCancel("remove provider", func() { + dispatcher.providers.Remove(provider) + }) + } + return provider, err + default: + return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName) } - return provider, err } func (disp *Dispatcher) start() { @@ -83,14 +99,14 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) { errs := E.NewBuilder(dispatchErr) disp.providers.RangeAllParallel(func(p Provider) { - if err := p.Send(task.Context(), msg); err != nil { + if err := notifyProvider(task.Context(), p, msg); err != nil { errs.Add(E.PrependSubject(p.Name(), err)) } }) if errs.HasError() { E.LogError(errs.About(), errs.Error()) } else { - logging.Debug().Str("title", msg.Title).Str("message", msg.Message).Msgf("dispatched notif") + logging.Debug().Str("title", msg.Title).Msgf("dispatched notif") } } @@ -108,10 +124,6 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) { // } // } -func Notify(title, msg string) { - dispatcher.logCh <- &LogMessage{ - Level: zerolog.InfoLevel, - Title: title, - Message: msg, - } +func Notify(msg *LogMessage) { + dispatcher.logCh <- msg } diff --git a/internal/notif/format.go b/internal/notif/format.go new file mode 100644 index 00000000..f78e38a3 --- /dev/null +++ b/internal/notif/format.go @@ -0,0 +1,36 @@ +package notif + +import ( + "bytes" + "encoding/json" + "fmt" +) + +func formatMarkdown(extras map[string]interface{}) string { + msg := bytes.NewBufferString("") + for k, v := range extras { + msg.WriteString("#### ") + msg.WriteString(k) + msg.WriteRune('\n') + msg.WriteString(fmt.Sprintf("%v", v)) + msg.WriteRune('\n') + } + return msg.String() +} + +func formatDiscord(extras map[string]interface{}) (string, error) { + fieldsMap := make([]map[string]any, len(extras)) + i := 0 + for k, extra := range extras { + fieldsMap[i] = map[string]any{ + "name": k, + "value": extra, + } + i++ + } + fields, err := json.Marshal(fieldsMap) + if err != nil { + return "", err + } + return string(fields), nil +} diff --git a/internal/notif/gotify.go b/internal/notif/gotify.go index 0e5de668..fc846896 100644 --- a/internal/notif/gotify.go +++ b/internal/notif/gotify.go @@ -2,57 +2,53 @@ package notif import ( "bytes" - "context" "encoding/json" "fmt" + "io" "net/http" - "net/url" "github.com/gotify/server/v2/model" "github.com/rs/zerolog" - E "github.com/yusing/go-proxy/internal/error" - U "github.com/yusing/go-proxy/internal/utils" ) type ( GotifyClient struct { - GotifyConfig - - url *url.URL - http http.Client + N string `json:"name" validate:"required"` + U string `json:"url" validate:"url"` + Tok string `json:"token" validate:"required"` } - GotifyConfig struct { - URL string `json:"url" yaml:"url"` - Token string `json:"token" yaml:"token"` - } - GotifyMessage model.Message + GotifyMessage model.MessageExternal ) const gotifyMsgEndpoint = "/message" -func newGotifyClient(cfg map[string]any) (Provider, E.Error) { - client := new(GotifyClient) - err := U.Deserialize(cfg, &client.GotifyConfig) - if err != nil { - return nil, err - } - - url, uErr := url.Parse(client.URL) - if uErr != nil { - return nil, E.Errorf("invalid gotify URL %s", client.URL) - } - - client.url = url - return client, err -} - -// Name implements NotifProvider. +// Name implements Provider. func (client *GotifyClient) Name() string { - return "gotify" + return client.N } -// Send implements NotifProvider. -func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error { +// Method implements Provider. +func (client *GotifyClient) Method() string { + return http.MethodPost +} + +// URL implements Provider. +func (client *GotifyClient) URL() string { + return client.U + gotifyMsgEndpoint +} + +// Token implements Provider. +func (client *GotifyClient) Token() string { + return client.Tok +} + +// MIMEType implements Provider. +func (client *GotifyClient) MIMEType() string { + return "application/json" +} + +// MakeBody implements Provider. +func (client *GotifyClient) MakeBody(logMsg *LogMessage) (io.Reader, error) { var priority int switch logMsg.Level { @@ -66,37 +62,29 @@ func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error msg := &GotifyMessage{ Title: logMsg.Title, - Message: logMsg.Message, - Priority: priority, + Message: formatMarkdown(logMsg.Extras), + Priority: &priority, + Extras: map[string]interface{}{ + "client::display": map[string]string{ + "contentType": "text/markdown", + }, + }, } data, err := json.Marshal(msg) if err != nil { - return err + return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.url.String()+gotifyMsgEndpoint, bytes.NewReader(data)) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+client.Token) - - resp, err := client.http.Do(req) - if err != nil { - return fmt.Errorf("failed to send gotify message: %w", err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - var errm model.Error - err = json.NewDecoder(resp.Body).Decode(&errm) - if err != nil { - return fmt.Errorf("gotify status %d, but failed to decode err response: %w", resp.StatusCode, err) - } - return fmt.Errorf("gotify status %d %s: %s", resp.StatusCode, errm.Error, errm.ErrorDescription) - } - return nil + return bytes.NewReader(data), nil +} + +// makeRespError implements Provider. +func (client *GotifyClient) makeRespError(resp *http.Response) error { + var errm model.Error + err := json.NewDecoder(resp.Body).Decode(&errm) + if err != nil { + return fmt.Errorf(ProviderGotify+" status %d, but failed to decode err response: %w", resp.StatusCode, err) + } + return fmt.Errorf(ProviderGotify+" status %d %s: %s", resp.StatusCode, errm.Error, errm.ErrorDescription) } diff --git a/internal/notif/gotify_test.go b/internal/notif/gotify_test.go new file mode 100644 index 00000000..fda36edc --- /dev/null +++ b/internal/notif/gotify_test.go @@ -0,0 +1,52 @@ +package notif + +import ( + "testing" + + "github.com/yusing/go-proxy/internal/utils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestGotifyValidation(t *testing.T) { + t.Parallel() + + newGotify := Providers[ProviderGotify] + + t.Run("valid", func(t *testing.T) { + t.Parallel() + _, err := newGotify(map[string]any{ + "name": "test", + "url": "https://example.com", + "token": "token", + }) + ExpectNoError(t, err) + }) + + t.Run("missing url", func(t *testing.T) { + t.Parallel() + _, err := newGotify(map[string]any{ + "name": "test", + "token": "token", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + + t.Run("missing token", func(t *testing.T) { + t.Parallel() + _, err := newGotify(map[string]any{ + "name": "test", + "url": "https://example.com", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + + t.Run("invalid url", func(t *testing.T) { + t.Parallel() + _, err := newGotify(map[string]any{ + "name": "test", + "url": "example.com", + "token": "token", + }) + ExpectError(t, utils.ErrValidationError, err) + }) +} diff --git a/internal/notif/providers.go b/internal/notif/providers.go index 5ab6e9ae..c09c6d1c 100644 --- a/internal/notif/providers.go +++ b/internal/notif/providers.go @@ -2,19 +2,78 @@ package notif import ( "context" + "fmt" + "io" + "net/http" E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + U "github.com/yusing/go-proxy/internal/utils" ) type ( Provider interface { Name() string - Send(ctx context.Context, logMsg *LogMessage) error + URL() string + Method() string + Token() string + MIMEType() string + MakeBody(logMsg *LogMessage) (io.Reader, error) + + makeRespError(resp *http.Response) error } ProviderCreateFunc func(map[string]any) (Provider, E.Error) ProviderConfig map[string]any ) +const ( + ProviderGotify = "gotify" + ProviderWebhook = "webhook" +) + var Providers = map[string]ProviderCreateFunc{ - "gotify": newGotifyClient, + ProviderGotify: newNotifProvider[*GotifyClient], + ProviderWebhook: newNotifProvider[*Webhook], +} + +func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) { + var client T + err := U.Deserialize(cfg, &client) + if err != nil { + return nil, err.Subject(client.Name()) + } + return client, nil +} + +func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error { + body, err := provider.MakeBody(msg) + if err != nil { + return fmt.Errorf("%s error: %w", provider.Name(), err) + } + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + provider.URL(), + body, + ) + if err != nil { + return fmt.Errorf("%s error: %w", provider.Name(), err) + } + + req.Header.Set("Content-Type", provider.MIMEType()) + if provider.Token() != "" { + req.Header.Set("Authorization", "Bearer "+provider.Token()) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("%s error: %w", provider.Name(), err) + } + + defer resp.Body.Close() + + if !gphttp.IsSuccess(resp.StatusCode) { + return provider.makeRespError(resp) + } + return nil } diff --git a/internal/notif/templates/discord.json b/internal/notif/templates/discord.json new file mode 100644 index 00000000..3caccb5a --- /dev/null +++ b/internal/notif/templates/discord.json @@ -0,0 +1,9 @@ +{ + "embeds": [ + { + "title": $title, + "fields": $fields, + "color": "$color" + } + ] +} \ No newline at end of file diff --git a/internal/notif/webhook.go b/internal/notif/webhook.go new file mode 100644 index 00000000..18a1c5a5 --- /dev/null +++ b/internal/notif/webhook.go @@ -0,0 +1,133 @@ +package notif + +import ( + _ "embed" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/yusing/go-proxy/internal/utils" +) + +type Webhook struct { + N string `json:"name" validate:"required"` + U string `json:"url" validate:"url"` + Template string `json:"template" validate:"omitempty,oneof=discord"` + Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"` + Tok string `json:"token"` + Meth string `json:"method" validate:"omitempty,oneof=GET POST PUT"` + MIMETyp string `json:"mime_type"` + ColorM string `json:"color_mode" validate:"omitempty,oneof=hex dec"` +} + +//go:embed templates/discord.json +var discordPayload string + +var webhookTemplates = map[string]string{ + "discord": discordPayload, +} + +func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool { + template := fl.Parent().FieldByName("Template").String() + if template != "" { + return true + } + payload := fl.Field().String() + return json.Valid([]byte(payload)) +} + +func init() { + utils.Validator().RegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed) +} + +// Name implements Provider. +func (webhook *Webhook) Name() string { + return webhook.N +} + +// Method implements Provider. +func (webhook *Webhook) Method() string { + if webhook.Meth != "" { + return webhook.Meth + } else { + return http.MethodPost + } +} + +// URL implements Provider. +func (webhook *Webhook) URL() string { + return webhook.U +} + +// Token implements Provider. +func (webhook *Webhook) Token() string { + return webhook.Tok +} + +// MIMEType implements Provider. +func (webhook *Webhook) MIMEType() string { + if webhook.MIMETyp != "" { + return webhook.MIMETyp + } else { + return "application/json" + } +} + +func (Webhook *Webhook) ColorMode() string { + switch Webhook.Template { + case "discord": + return "dec" + default: + return Webhook.ColorM + } +} + +// makeRespError implements Provider. +func (webhook *Webhook) makeRespError(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("webhook status %d, failed to read body: %w", resp.StatusCode, err) + } + if len(body) > 0 { + return fmt.Errorf("webhook status %d: %s", resp.StatusCode, body) + } + return fmt.Errorf("webhook status %d", resp.StatusCode) +} + +func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) { + title, err := json.Marshal(logMsg.Title) + if err != nil { + return nil, err + } + fields, err := formatDiscord(logMsg.Extras) + if err != nil { + return nil, err + } + var color string + if webhook.ColorMode() == "hex" { + color = logMsg.Color.HexString() + } else { + color = logMsg.Color.DecString() + } + message, err := json.Marshal(formatMarkdown(logMsg.Extras)) + if err != nil { + return nil, err + } + plTempl := strings.NewReplacer( + "$title", string(title), + "$message", string(message), + "$fields", string(fields), + "$color", color, + ) + var pl string + if webhook.Template != "" { + pl = webhookTemplates[webhook.Template] + } else { + pl = webhook.Payload + } + pl = plTempl.Replace(pl) + return strings.NewReader(pl), nil +} diff --git a/internal/notif/webhook_test.go b/internal/notif/webhook_test.go new file mode 100644 index 00000000..85fc5884 --- /dev/null +++ b/internal/notif/webhook_test.go @@ -0,0 +1,112 @@ +package notif + +import ( + "encoding/json" + "testing" + + "github.com/yusing/go-proxy/internal/utils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestWebhookValidation(t *testing.T) { + t.Parallel() + + newWebhook := Providers[ProviderWebhook] + + t.Run("valid", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + "payload": "{}", + }) + ExpectNoError(t, err) + }) + t.Run("valid template", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + "template": "discord", + }) + ExpectNoError(t, err) + }) + + t.Run("missing url", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "payload": "{}", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + + t.Run("missing payload", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + t.Run("invalid url", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "example.com", + "payload": "{}", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + t.Run("invalid payload", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + "payload": "abcd", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + t.Run("invalid method", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + "payload": "{}", + "method": "abcd", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + t.Run("invalid template", func(t *testing.T) { + t.Parallel() + _, err := newWebhook(map[string]any{ + "name": "test", + "url": "https://example.com", + "template": "abcd", + }) + ExpectError(t, utils.ErrValidationError, err) + }) +} + +func TestWebhookBody(t *testing.T) { + t.Parallel() + + var webhook Webhook + webhook.Payload = discordPayload + bodyReader, err := webhook.MakeBody(&LogMessage{ + Title: "abc", + Extras: map[string]any{ + "foo": "bar", + }, + }) + ExpectNoError(t, err) + + var body map[string][]map[string]any + err = json.NewDecoder(bodyReader).Decode(&body) + ExpectNoError(t, err) + + ExpectEqual(t, body["embeds"][0]["title"], "abc") + fields := ExpectType[[]map[string]any](t, body["embeds"][0]["fields"]) + ExpectEqual(t, fields[0]["name"], "foo") + ExpectEqual(t, fields[0]["value"], "bar") +}