mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-19 15:31:37 +02:00
v0.26.0
This commit is contained in:
@@ -30,9 +30,11 @@ Internal package with stable core types. Route configuration schema is versioned
|
||||
type Route struct {
|
||||
Alias string // Unique route identifier
|
||||
Scheme Scheme // http, https, h2c, tcp, udp, fileserver
|
||||
Host string // Virtual host
|
||||
Host string // Target host
|
||||
Port Port // Listen and target ports
|
||||
|
||||
Bind string // Bind address for listening (IP address, optional)
|
||||
|
||||
// File serving
|
||||
Root string // Document root
|
||||
SPA bool // Single-page app mode
|
||||
@@ -91,8 +93,8 @@ const (
|
||||
|
||||
```go
|
||||
// Validation and lifecycle
|
||||
func (r *Route) Validate() gperr.Error
|
||||
func (r *Route) Start(parent task.Parent) gperr.Error
|
||||
func (r *Route) Validate() error
|
||||
func (r *Route) Start(parent task.Parent) error
|
||||
func (r *Route) Finish(reason any)
|
||||
func (r *Route) Started() <-chan struct{}
|
||||
|
||||
@@ -117,8 +119,8 @@ func (r *Route) UseHealthCheck() bool
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Route {
|
||||
+Validate() gperr.Error
|
||||
+Start(parent) gperr.Error
|
||||
+Validate() error
|
||||
+Start(parent) error
|
||||
+Finish(reason)
|
||||
+Started() <-chan struct#123;#125;
|
||||
}
|
||||
@@ -196,6 +198,7 @@ type Route struct {
|
||||
Alias string `json:"alias"`
|
||||
Scheme Scheme `json:"scheme"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Bind string `json:"bind,omitempty"` // Listen bind address
|
||||
Port Port `json:"port"`
|
||||
Root string `json:"root,omitempty"`
|
||||
SPA bool `json:"spa,omitempty"`
|
||||
@@ -218,23 +221,26 @@ labels:
|
||||
routes:
|
||||
myapp:
|
||||
scheme: http
|
||||
root: /var/www/myapp
|
||||
spa: true
|
||||
host: myapp.local
|
||||
bind: 192.168.1.100 # Optional: bind to specific address
|
||||
port:
|
||||
proxy: 80
|
||||
target: 3000
|
||||
```
|
||||
|
||||
## Dependency and Integration Map
|
||||
|
||||
| Dependency | Purpose |
|
||||
| -------------------------------- | -------------------------------- |
|
||||
| `internal/route/routes` | Route registry and lookup |
|
||||
| `internal/route/rules` | Request/response rule processing |
|
||||
| `internal/route/stream` | TCP/UDP stream proxying |
|
||||
| `internal/route/provider` | Route discovery and loading |
|
||||
| `internal/health/monitor` | Health checking |
|
||||
| `internal/idlewatcher` | Idle container management |
|
||||
| `internal/logging/accesslog` | Request logging |
|
||||
| `internal/homepage` | Dashboard integration |
|
||||
| `github.com/yusing/goutils/errs` | Error handling |
|
||||
| Dependency | Purpose |
|
||||
| ---------------------------------- | --------------------------------- |
|
||||
| `internal/route/routes/context.go` | Route context helpers (only file) |
|
||||
| `internal/route/rules` | Request/response rule processing |
|
||||
| `internal/route/stream` | TCP/UDP stream proxying |
|
||||
| `internal/route/provider` | Route discovery and loading |
|
||||
| `internal/health/monitor` | Health checking |
|
||||
| `internal/idlewatcher` | Idle container management |
|
||||
| `internal/logging/accesslog` | Request logging |
|
||||
| `internal/homepage` | Dashboard integration |
|
||||
| `github.com/yusing/goutils/errs` | Error handling |
|
||||
|
||||
## Observability
|
||||
|
||||
@@ -305,6 +311,18 @@ route := &route.Route{
|
||||
}
|
||||
```
|
||||
|
||||
### Route with Custom Bind Address
|
||||
|
||||
```go
|
||||
route := &route.Route{
|
||||
Alias: "myapp",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "myapp.local",
|
||||
Bind: "192.168.1.100", // Bind to specific interface
|
||||
Port: route.Port{Listening: 8443, Proxy: 80},
|
||||
}
|
||||
```
|
||||
|
||||
### File Server Route
|
||||
|
||||
```go
|
||||
|
||||
@@ -1,27 +1,36 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
func checkExists(r types.Route) gperr.Error {
|
||||
// checkExists checks if the route already exists in the entrypoint.
|
||||
//
|
||||
// Context must be passed from the parent task that carries the entrypoint value.
|
||||
func checkExists(ctx context.Context, r types.Route) error {
|
||||
if r.UseLoadBalance() { // skip checking for load balanced routes
|
||||
return nil
|
||||
}
|
||||
ep := entrypoint.FromCtx(ctx)
|
||||
if ep == nil {
|
||||
return fmt.Errorf("entrypoint not found in context")
|
||||
}
|
||||
var (
|
||||
existing types.Route
|
||||
ok bool
|
||||
)
|
||||
switch r := r.(type) {
|
||||
case types.HTTPRoute:
|
||||
existing, ok = routes.HTTP.Get(r.Key())
|
||||
existing, ok = entrypoint.FromCtx(ctx).HTTPRoutes().Get(r.Key())
|
||||
case types.StreamRoute:
|
||||
existing, ok = routes.Stream.Get(r.Key())
|
||||
existing, ok = entrypoint.FromCtx(ctx).StreamRoutes().Get(r.Key())
|
||||
}
|
||||
if ok {
|
||||
return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName())
|
||||
return fmt.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
config "github.com/yusing/godoxy/internal/config/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/health/monitor"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
||||
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/task"
|
||||
@@ -50,12 +51,12 @@ func handler(root string, spa bool, index string) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func NewFileServer(base *Route) (*FileServer, gperr.Error) {
|
||||
func NewFileServer(base *Route) (*FileServer, error) {
|
||||
s := &FileServer{Route: base}
|
||||
|
||||
s.Root = filepath.Clean(s.Root)
|
||||
if !path.IsAbs(s.Root) {
|
||||
return nil, gperr.New("`root` must be an absolute path")
|
||||
if !filepath.IsAbs(s.Root) {
|
||||
return nil, errors.New("`root` must be an absolute path")
|
||||
}
|
||||
|
||||
if s.Index == "" {
|
||||
@@ -77,8 +78,9 @@ func NewFileServer(base *Route) (*FileServer, gperr.Error) {
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (s *FileServer) Start(parent task.Parent) gperr.Error {
|
||||
func (s *FileServer) Start(parent task.Parent) error {
|
||||
s.task = parent.Subtask("fileserver."+s.Name(), false)
|
||||
s.task.SetValue(monitor.DisplayNameKey{}, s.DisplayName())
|
||||
|
||||
pathPatterns := s.PathPatterns
|
||||
switch {
|
||||
@@ -109,7 +111,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
|
||||
s.accessLogger, err = accesslog.NewAccessLogger(s.task, s.AccessLog)
|
||||
if err != nil {
|
||||
s.task.Finish(err)
|
||||
return gperr.Wrap(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,20 +122,22 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
|
||||
if s.UseHealthCheck() {
|
||||
s.HealthMon = monitor.NewMonitor(s)
|
||||
if err := s.HealthMon.Start(s.task); err != nil {
|
||||
return err
|
||||
log.Warn().EmbedObject(s).Err(err).Msg("health monitor error")
|
||||
s.HealthMon = nil
|
||||
}
|
||||
}
|
||||
|
||||
routes.HTTP.Add(s)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().AddRoute(s.Alias)
|
||||
ep := entrypoint.FromCtx(parent.Context())
|
||||
if ep == nil {
|
||||
err := errors.New("entrypoint not initialized")
|
||||
s.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ep.StartAddRoute(s); err != nil {
|
||||
s.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
s.task.OnFinished("remove_route_from_http", func() {
|
||||
routes.HTTP.Del(s)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().DelRoute(s.Alias)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ type ProviderImpl interface {
|
||||
fmt.Stringer
|
||||
ShortName() string
|
||||
IsExplicitOnly() bool
|
||||
loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
loadRoutesImpl() (route.Routes, error)
|
||||
NewWatcher() W.Watcher
|
||||
Logger() *zerolog.Logger
|
||||
}
|
||||
@@ -62,8 +62,8 @@ func NewAgentProvider(cfg *agent.AgentConfig) *Provider
|
||||
|
||||
```go
|
||||
func (p *Provider) GetType() provider.Type
|
||||
func (p *Provider) Start(parent task.Parent) gperr.Error
|
||||
func (p *Provider) LoadRoutes() gperr.Error
|
||||
func (p *Provider) Start(parent task.Parent) error
|
||||
func (p *Provider) LoadRoutes() error
|
||||
func (p *Provider) IterRoutes(yield func(string, types.Route) bool)
|
||||
func (p *Provider) GetRoute(alias string) (types.Route, bool)
|
||||
func (p *Provider) FindService(project, service string) (types.Route, bool)
|
||||
@@ -80,8 +80,8 @@ classDiagram
|
||||
+t provider.Type
|
||||
+routes route.Routes
|
||||
+watcher W.Watcher
|
||||
+Start(parent) gperr.Error
|
||||
+LoadRoutes() gperr.Error
|
||||
+Start(parent) error
|
||||
+LoadRoutes() error
|
||||
+IterRoutes(yield)
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ classDiagram
|
||||
+String() string
|
||||
+ShortName() string
|
||||
+IsExplicitOnly() bool
|
||||
+loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
+loadRoutesImpl() (route.Routes, error)
|
||||
+NewWatcher() W.Watcher
|
||||
+Logger() *zerolog.Logger
|
||||
}
|
||||
@@ -99,20 +99,20 @@ classDiagram
|
||||
+name string
|
||||
+dockerCfg types.DockerProviderConfig
|
||||
+ShortName() string
|
||||
+loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
+loadRoutesImpl() (route.Routes, error)
|
||||
}
|
||||
|
||||
class FileProviderImpl {
|
||||
+filename string
|
||||
+ShortName() string
|
||||
+loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
+loadRoutesImpl() (route.Routes, error)
|
||||
}
|
||||
|
||||
class AgentProviderImpl {
|
||||
+*agent.AgentConfig
|
||||
+docker DockerProviderImpl
|
||||
+ShortName() string
|
||||
+loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
+loadRoutesImpl() (route.Routes, error)
|
||||
}
|
||||
|
||||
Provider --> ProviderImpl : wraps
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/watcher"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type AgentProvider struct {
|
||||
@@ -25,7 +24,7 @@ func (p *AgentProvider) IsExplicitOnly() bool {
|
||||
return p.docker.IsExplicitOnly()
|
||||
}
|
||||
|
||||
func (p *AgentProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
func (p *AgentProvider) loadRoutesImpl() (route.Routes, error) {
|
||||
return p.docker.loadRoutesImpl()
|
||||
}
|
||||
|
||||
|
||||
@@ -2,18 +2,44 @@ example: # matching `example.y.z`
|
||||
scheme: http
|
||||
host: 10.0.0.254
|
||||
port: 80
|
||||
bind: 0.0.0.0
|
||||
root: /var/www/example
|
||||
spa: true
|
||||
index: index.html
|
||||
no_tls_verify: true
|
||||
disable_compression: false
|
||||
response_header_timeout: 30s
|
||||
ssl_server_name: "" # empty uses target hostname, "off" disables SNI
|
||||
ssl_trusted_certificate: /etc/ssl/certs/ca-certificates.crt
|
||||
ssl_certificate: /etc/ssl/client.crt
|
||||
ssl_certificate_key: /etc/ssl/client.key
|
||||
ssl_protocols:
|
||||
- tlsv1.2
|
||||
- tlsv1.3
|
||||
path_patterns: # Check https://pkg.go.dev/net/http#hdr-Patterns-ServeMux for syntax
|
||||
- GET / # accept any GET request
|
||||
- POST /auth # for /auth and /auth/* accept only POST
|
||||
- GET /home/{$} # for exactly /home
|
||||
rules:
|
||||
- name: default
|
||||
do: pass
|
||||
- name: block-admin
|
||||
on: path /admin
|
||||
do: error 403 Forbidden
|
||||
rule_file: embed://webui.yml
|
||||
healthcheck:
|
||||
disabled: false
|
||||
use_get: true
|
||||
path: /
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: -1 # -1: immediate fail, 0: use default, >0: retry count
|
||||
load_balance:
|
||||
link: app
|
||||
mode: ip_hash
|
||||
link: app # link to another route alias
|
||||
mode: roundrobin # roundrobin, leastconn, iphash
|
||||
weight: 1
|
||||
sticky: false
|
||||
sticky_max_age: 1h
|
||||
options:
|
||||
header: X-Forwarded-For
|
||||
middlewares:
|
||||
@@ -23,15 +49,19 @@ example: # matching `example.y.z`
|
||||
- 10.0.0.0/8
|
||||
status_code: 403
|
||||
message: IP not allowed
|
||||
hideXForwarded:
|
||||
homepage:
|
||||
show: true
|
||||
name: Example App
|
||||
icon: "@selfhst/adguard-home.png"
|
||||
description: An example app
|
||||
category: example
|
||||
access_log:
|
||||
buffer_size: 100
|
||||
path: /var/log/example.log
|
||||
stdout: false
|
||||
retention:
|
||||
days: 30
|
||||
rotate_interval: 24h
|
||||
format: combined # common, combined, json
|
||||
filters:
|
||||
status_codes:
|
||||
values:
|
||||
@@ -53,14 +83,29 @@ example: # matching `example.y.z`
|
||||
- 192.168.10.0/24
|
||||
fields:
|
||||
headers:
|
||||
default: keep
|
||||
config:
|
||||
foo: redact
|
||||
query:
|
||||
default: drop
|
||||
config:
|
||||
foo: keep
|
||||
cookies:
|
||||
default: redact
|
||||
foo: redact
|
||||
authorization: drop
|
||||
query:
|
||||
default: keep
|
||||
config:
|
||||
foo: keep
|
||||
password: redact
|
||||
cookies:
|
||||
default: drop
|
||||
config:
|
||||
session: keep
|
||||
idlewatcher:
|
||||
idle_timeout: 30m
|
||||
wake_timeout: 30s
|
||||
stop_timeout: 1m
|
||||
stop_method: stop # pause, stop, kill
|
||||
stop_signal: SIGTERM
|
||||
start_endpoint: /api/wake
|
||||
depends_on:
|
||||
- other-service
|
||||
no_loading_page: false
|
||||
docker:
|
||||
container_id: abc123
|
||||
container_name: example-app
|
||||
|
||||
@@ -58,13 +58,13 @@ func (p *DockerProvider) NewWatcher() watcher.Watcher {
|
||||
return watcher.NewDockerWatcher(p.dockerCfg)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
func (p *DockerProvider) loadRoutesImpl() (route.Routes, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
containers, err := docker.ListContainers(ctx, p.dockerCfg)
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errs := gperr.NewBuilder("")
|
||||
@@ -74,21 +74,21 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
container := docker.FromDocker(&c, p.dockerCfg)
|
||||
|
||||
if container.Errors != nil {
|
||||
errs.Add(gperr.PrependSubject(container.ContainerName, container.Errors))
|
||||
errs.AddSubject(container.Errors, container.ContainerName)
|
||||
continue
|
||||
}
|
||||
|
||||
if container.IsHostNetworkMode {
|
||||
err := docker.UpdatePorts(ctx, container)
|
||||
if err != nil {
|
||||
errs.Add(gperr.PrependSubject(container.ContainerName, err))
|
||||
errs.AddSubject(err, container.ContainerName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newEntries, err := p.routesFromContainerLabels(container)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(container.ContainerName))
|
||||
errs.AddSubject(err, container.ContainerName)
|
||||
}
|
||||
for k, v := range newEntries {
|
||||
if conflict, ok := routes[k]; ok {
|
||||
@@ -97,7 +97,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
Addf("container %s", container.ContainerName).
|
||||
Addf("conflicting container %s", conflict.Container.ContainerName)
|
||||
if conflict.ShouldExclude() || v.ShouldExclude() {
|
||||
gperr.LogWarn("skipping conflicting route", err)
|
||||
log.Warn().Err(err).Msg("skipping conflicting route")
|
||||
} else {
|
||||
errs.Add(err)
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
|
||||
// Returns a list of proxy entries for a container.
|
||||
// Always non-nil.
|
||||
func (p *DockerProvider) routesFromContainerLabels(container *types.Container) (route.Routes, gperr.Error) {
|
||||
func (p *DockerProvider) routesFromContainerLabels(container *types.Container) (route.Routes, error) {
|
||||
if !container.IsExplicit && p.IsExplicitOnly() {
|
||||
return make(route.Routes, 0), nil
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *types.Container) (
|
||||
panic(fmt.Errorf("invalid entry map type %T", entryMapAny))
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &entryMap); err != nil {
|
||||
errs.Add(gperr.Wrap(err).Subject(alias))
|
||||
errs.AddSubject(err, alias)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *types.Container) (
|
||||
// deserialize map into entry object
|
||||
err := serialization.MapUnmarshalValidate(entryMap, r)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(alias))
|
||||
errs.AddSubject(err, alias)
|
||||
} else {
|
||||
routes[alias] = r
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ proxy.app: |
|
||||
description: An example app
|
||||
category: example
|
||||
access_log:
|
||||
buffer_size: 100
|
||||
path: /var/log/example.log
|
||||
filters:
|
||||
status_codes:
|
||||
@@ -92,7 +91,6 @@ proxy.app1.homepage.name: Example App
|
||||
proxy.app1.homepage.icon: "@selfhst/adguard-home.png"
|
||||
proxy.app1.homepage.description: An example app
|
||||
proxy.app1.homepage.category: example
|
||||
proxy.app1.access_log.buffer_size: 100
|
||||
proxy.app1.access_log.path: /var/log/example.log
|
||||
proxy.app1.access_log.filters: |
|
||||
status_codes:
|
||||
|
||||
@@ -81,7 +81,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
|
||||
func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
|
||||
err := handler.provider.startRoute(parent, route)
|
||||
if err != nil {
|
||||
handler.errs.Add(err.Subject("add"))
|
||||
handler.errs.AddSubjectf(err, "add")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +93,12 @@ func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, n
|
||||
oldRoute.FinishAndWait("route update")
|
||||
err := handler.provider.startRoute(parent, newRoute)
|
||||
if err != nil {
|
||||
handler.errs.Add(err.Subject("update"))
|
||||
handler.errs.AddSubjectf(err, "update")
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Log() {
|
||||
if err := handler.errs.Error(); err != nil {
|
||||
handler.provider.Logger().Info().Msg(err.Error())
|
||||
handler.provider.Logger().Error().Msg(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
W "github.com/yusing/godoxy/internal/watcher"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type FileProvider struct {
|
||||
@@ -34,7 +33,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
|
||||
return impl, nil
|
||||
}
|
||||
|
||||
func removeXPrefix(m map[string]any) gperr.Error {
|
||||
func removeXPrefix(m map[string]any) error {
|
||||
for alias := range m {
|
||||
if strings.HasPrefix(alias, "x-") {
|
||||
delete(m, alias)
|
||||
@@ -43,12 +42,12 @@ func removeXPrefix(m map[string]any) gperr.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validate(data []byte) (routes route.Routes, err gperr.Error) {
|
||||
func validate(data []byte) (routes route.Routes, err error) {
|
||||
err = serialization.UnmarshalValidate(data, &routes, yaml.Unmarshal, removeXPrefix)
|
||||
return routes, err
|
||||
}
|
||||
|
||||
func Validate(data []byte) (err gperr.Error) {
|
||||
func Validate(data []byte) (err error) {
|
||||
_, err = validate(data)
|
||||
return err
|
||||
}
|
||||
@@ -69,16 +68,16 @@ func (p *FileProvider) Logger() *zerolog.Logger {
|
||||
return &p.l
|
||||
}
|
||||
|
||||
func (p *FileProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||
func (p *FileProvider) loadRoutesImpl() (route.Routes, error) {
|
||||
data, err := os.ReadFile(p.path)
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
return nil, err
|
||||
}
|
||||
routes, err := validate(data)
|
||||
if err != nil && len(routes) == 0 {
|
||||
return nil, gperr.Wrap(err)
|
||||
return nil, err
|
||||
}
|
||||
return routes, gperr.Wrap(err)
|
||||
return routes, err
|
||||
}
|
||||
|
||||
func (p *FileProvider) NewWatcher() W.Watcher {
|
||||
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
provider "github.com/yusing/godoxy/internal/route/provider/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
W "github.com/yusing/godoxy/internal/watcher"
|
||||
"github.com/yusing/godoxy/internal/watcher/events"
|
||||
watcherEvents "github.com/yusing/godoxy/internal/watcher/events"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/eventqueue"
|
||||
"github.com/yusing/goutils/events"
|
||||
"github.com/yusing/goutils/task"
|
||||
)
|
||||
|
||||
@@ -34,7 +36,7 @@ type (
|
||||
fmt.Stringer
|
||||
ShortName() string
|
||||
IsExplicitOnly() bool
|
||||
loadRoutesImpl() (route.Routes, gperr.Error)
|
||||
loadRoutesImpl() (route.Routes, error)
|
||||
NewWatcher() W.Watcher
|
||||
Logger() *zerolog.Logger
|
||||
}
|
||||
@@ -90,13 +92,13 @@ func (p *Provider) GetType() provider.Type {
|
||||
return p.t
|
||||
}
|
||||
|
||||
// to work with json marshaller.
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (p *Provider) MarshalText() ([]byte, error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (p *Provider) Start(parent task.Parent) gperr.Error {
|
||||
func (p *Provider) Start(parent task.Parent) error {
|
||||
errs := gperr.NewGroup("routes error")
|
||||
|
||||
t := parent.Subtask("provider."+p.String(), false)
|
||||
@@ -115,19 +117,29 @@ func (p *Provider) Start(parent task.Parent) gperr.Error {
|
||||
|
||||
err := errs.Wait().Error()
|
||||
|
||||
eventQueue := events.NewEventQueue(
|
||||
t.Subtask("event_queue", false),
|
||||
providerEventFlushInterval,
|
||||
func(events []events.Event) {
|
||||
opts := eventqueue.Options[watcherEvents.Event]{
|
||||
FlushInterval: providerEventFlushInterval,
|
||||
OnFlush: func(evs []watcherEvents.Event) {
|
||||
handler := p.newEventHandler()
|
||||
// routes' lifetime should follow the provider's lifetime
|
||||
handler.Handle(t, events)
|
||||
handler.Handle(t, evs)
|
||||
handler.Log()
|
||||
|
||||
globalEvents := make([]events.Event, len(evs))
|
||||
for i, ev := range evs {
|
||||
globalEvents[i] = events.NewEvent(events.LevelInfo, "provider_event", ev.Action.String(), map[string]any{
|
||||
"provider": p.String(),
|
||||
"type": ev.Type, // file / docker
|
||||
"actor": ev.ActorName, // file path / container name
|
||||
})
|
||||
}
|
||||
events.Global.AddAll(globalEvents)
|
||||
},
|
||||
func(err gperr.Error) {
|
||||
gperr.LogError("event error", err, p.Logger())
|
||||
OnError: func(err error) {
|
||||
p.Logger().Err(err).Msg("event error")
|
||||
},
|
||||
)
|
||||
}
|
||||
eventQueue := eventqueue.New(t.Subtask("event_queue", false), opts)
|
||||
eventQueue.Start(p.watcher.Events(t.Context()))
|
||||
|
||||
if err != nil {
|
||||
@@ -136,7 +148,7 @@ func (p *Provider) Start(parent task.Parent) gperr.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) LoadRoutes() (err gperr.Error) {
|
||||
func (p *Provider) LoadRoutes() (err error) {
|
||||
p.routes, err = p.loadRoutes()
|
||||
return err
|
||||
}
|
||||
@@ -188,7 +200,7 @@ func (p *Provider) GetRoute(alias string) (types.Route, bool) {
|
||||
return r.Impl(), true
|
||||
}
|
||||
|
||||
func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
|
||||
func (p *Provider) loadRoutes() (routes route.Routes, err error) {
|
||||
routes, err = p.loadRoutesImpl()
|
||||
if err != nil && len(routes) == 0 {
|
||||
return route.Routes{}, err
|
||||
@@ -201,7 +213,7 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
|
||||
r.Alias = alias
|
||||
r.SetProvider(p)
|
||||
if err := r.Validate(); err != nil {
|
||||
errs.Add(err.Subject(alias))
|
||||
errs.AddSubject(err, alias)
|
||||
delete(routes, alias)
|
||||
continue
|
||||
}
|
||||
@@ -210,11 +222,11 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
|
||||
return routes, errs.Error()
|
||||
}
|
||||
|
||||
func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error {
|
||||
func (p *Provider) startRoute(parent task.Parent, r *route.Route) error {
|
||||
err := r.Start(parent)
|
||||
if err != nil {
|
||||
p.lockDeleteRoute(r.Alias)
|
||||
return err.Subject(r.Alias)
|
||||
return gperr.PrependSubject(err, r.Alias)
|
||||
}
|
||||
|
||||
p.lockAddRoute(r)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agentproxy"
|
||||
config "github.com/yusing/godoxy/internal/config/types"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/health/monitor"
|
||||
"github.com/yusing/godoxy/internal/idlewatcher"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
@@ -14,16 +16,14 @@ import (
|
||||
"github.com/yusing/godoxy/internal/net/gphttp/loadbalancer"
|
||||
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
route "github.com/yusing/godoxy/internal/route/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
"github.com/yusing/goutils/task"
|
||||
"github.com/yusing/goutils/version"
|
||||
)
|
||||
|
||||
type ReveseProxyRoute struct {
|
||||
type ReverseProxyRoute struct {
|
||||
*Route
|
||||
|
||||
loadBalancer *loadbalancer.LoadBalancer
|
||||
@@ -31,11 +31,11 @@ type ReveseProxyRoute struct {
|
||||
rp *reverseproxy.ReverseProxy
|
||||
}
|
||||
|
||||
var _ types.ReverseProxyRoute = (*ReveseProxyRoute)(nil)
|
||||
var _ types.ReverseProxyRoute = (*ReverseProxyRoute)(nil)
|
||||
|
||||
// var globalMux = http.NewServeMux() // TODO: support regex subdomain matching.
|
||||
|
||||
func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
|
||||
func NewReverseProxyRoute(base *Route) (*ReverseProxyRoute, error) {
|
||||
httpConfig := base.HTTPConfig
|
||||
proxyURL := base.ProxyURL
|
||||
|
||||
@@ -111,7 +111,7 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
r := &ReveseProxyRoute{
|
||||
r := &ReverseProxyRoute{
|
||||
Route: base,
|
||||
rp: rp,
|
||||
}
|
||||
@@ -119,20 +119,21 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
|
||||
}
|
||||
|
||||
// ReverseProxy implements routes.ReverseProxyRoute.
|
||||
func (r *ReveseProxyRoute) ReverseProxy() *reverseproxy.ReverseProxy {
|
||||
func (r *ReverseProxyRoute) ReverseProxy() *reverseproxy.ReverseProxy {
|
||||
return r.rp
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
|
||||
func (r *ReverseProxyRoute) Start(parent task.Parent) error {
|
||||
r.task = parent.Subtask("http."+r.Name(), false)
|
||||
r.task.SetValue(monitor.DisplayNameKey{}, r.DisplayName())
|
||||
|
||||
switch {
|
||||
case r.UseIdleWatcher():
|
||||
waker, err := idlewatcher.NewWatcher(parent, r, r.IdlewatcherConfig())
|
||||
if err != nil {
|
||||
r.task.Finish(err)
|
||||
return gperr.Wrap(err)
|
||||
return err
|
||||
}
|
||||
r.handler = waker
|
||||
r.HealthMon = waker
|
||||
@@ -149,7 +150,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
|
||||
r.rp.AccessLogger, err = accesslog.NewAccessLogger(r.task, r.AccessLog)
|
||||
if err != nil {
|
||||
r.task.Finish(err)
|
||||
return gperr.Wrap(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,44 +160,50 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
|
||||
|
||||
if r.HealthMon != nil {
|
||||
if err := r.HealthMon.Start(r.task); err != nil {
|
||||
return err
|
||||
// TODO: add to event history
|
||||
log.Warn().Err(err).Msg("health monitor error")
|
||||
r.HealthMon = nil
|
||||
}
|
||||
}
|
||||
|
||||
ep := entrypoint.FromCtx(parent.Context())
|
||||
if ep == nil {
|
||||
err := errors.New("entrypoint not initialized")
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if r.UseLoadBalance() {
|
||||
r.addToLoadBalancer(parent)
|
||||
} else {
|
||||
routes.HTTP.Add(r)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().AddRoute(r.Alias)
|
||||
if err := r.addToLoadBalancer(parent, ep); err != nil {
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := ep.StartAddRoute(r); err != nil {
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
r.task.OnCancel("remove_route", func() {
|
||||
routes.HTTP.Del(r)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().DelRoute(r.Alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ReveseProxyRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
func (r *ReverseProxyRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// req.Header.Set("Accept-Encoding", "identity")
|
||||
r.handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
var lbLock sync.Mutex
|
||||
|
||||
func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
|
||||
func (r *ReverseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) error {
|
||||
var lb *loadbalancer.LoadBalancer
|
||||
cfg := r.LoadBalance
|
||||
lbLock.Lock()
|
||||
defer lbLock.Unlock()
|
||||
|
||||
l, ok := routes.HTTP.Get(cfg.Link)
|
||||
var linked *ReveseProxyRoute
|
||||
l, ok := ep.HTTPRoutes().Get(cfg.Link)
|
||||
var linked *ReverseProxyRoute
|
||||
if ok {
|
||||
lbLock.Unlock()
|
||||
linked = l.(*ReveseProxyRoute)
|
||||
linked = l.(*ReverseProxyRoute) // it must be a reverse proxy route
|
||||
lb = linked.loadBalancer
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
if linked.Homepage.Name == "" {
|
||||
@@ -205,26 +212,24 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
|
||||
} else {
|
||||
lb = loadbalancer.New(cfg)
|
||||
_ = lb.Start(parent) // always return nil
|
||||
linked = &ReveseProxyRoute{
|
||||
linked = &ReverseProxyRoute{
|
||||
Route: &Route{
|
||||
Alias: cfg.Link,
|
||||
Homepage: r.Homepage,
|
||||
Bind: r.Bind,
|
||||
Metadata: Metadata{
|
||||
LisURL: r.ListenURL(),
|
||||
task: lb.Task(),
|
||||
},
|
||||
},
|
||||
loadBalancer: lb,
|
||||
handler: lb,
|
||||
}
|
||||
linked.SetHealthMonitor(lb)
|
||||
routes.HTTP.AddKey(cfg.Link, linked)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().AddRoute(cfg.Link)
|
||||
if err := ep.StartAddRoute(linked); err != nil {
|
||||
lb.Finish(err)
|
||||
return err
|
||||
}
|
||||
r.task.OnFinished("remove_loadbalancer_route", func() {
|
||||
routes.HTTP.DelKey(cfg.Link)
|
||||
if state := config.WorkingState.Load(); state != nil {
|
||||
state.ShortLinkMatcher().DelRoute(cfg.Link)
|
||||
}
|
||||
})
|
||||
lbLock.Unlock()
|
||||
}
|
||||
r.loadBalancer = lb
|
||||
|
||||
@@ -233,4 +238,5 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
|
||||
r.task.OnCancel("lb_remove_server", func() {
|
||||
lb.RemoveServer(server)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
39
internal/route/reverse_proxy_test.go
Normal file
39
internal/route/reverse_proxy_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
route "github.com/yusing/godoxy/internal/route/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
)
|
||||
|
||||
func TestReverseProxyRoute(t *testing.T) {
|
||||
t.Run("LinkToLoadBalancer", func(t *testing.T) {
|
||||
cfg := Route{
|
||||
Alias: "test",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: Port{Proxy: 80},
|
||||
LoadBalance: &types.LoadBalancerConfig{
|
||||
Link: "test",
|
||||
},
|
||||
}
|
||||
cfg1 := Route{
|
||||
Alias: "test1",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: Port{Proxy: 80},
|
||||
LoadBalance: &types.LoadBalancerConfig{
|
||||
Link: "test",
|
||||
},
|
||||
}
|
||||
r, err := NewStartedTestRoute(t, &cfg)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
r2, err := NewStartedTestRoute(t, &cfg1)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r2)
|
||||
})
|
||||
}
|
||||
@@ -14,10 +14,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/agentpool"
|
||||
config "github.com/yusing/godoxy/internal/config/types"
|
||||
"github.com/yusing/godoxy/internal/docker"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/health/monitor"
|
||||
"github.com/yusing/godoxy/internal/homepage"
|
||||
iconlist "github.com/yusing/godoxy/internal/homepage/icons/list"
|
||||
@@ -33,7 +35,6 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
rulepresets "github.com/yusing/godoxy/internal/route/rules/presets"
|
||||
route "github.com/yusing/godoxy/internal/route/types"
|
||||
@@ -46,7 +47,6 @@ type (
|
||||
Host string `json:"host,omitempty"`
|
||||
Port route.Port `json:"port"`
|
||||
|
||||
// for TCP and UDP routes, bind address to listen on
|
||||
Bind string `json:"bind,omitempty" validate:"omitempty,ip_addr" extensions:"x-nullable"`
|
||||
|
||||
Root string `json:"root,omitempty"`
|
||||
@@ -57,7 +57,7 @@ type (
|
||||
PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"`
|
||||
Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"`
|
||||
RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"`
|
||||
HealthCheck types.HealthCheckConfig `json:"healthcheck,omitempty" extensions:"x-nullable"` // null on load-balancer routes
|
||||
HealthCheck types.HealthCheckConfig `json:"healthcheck,omitzero" extensions:"x-nullable"` // null on load-balancer routes
|
||||
LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"`
|
||||
Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"`
|
||||
Homepage *homepage.ItemConfig `json:"homepage"`
|
||||
@@ -108,17 +108,17 @@ type (
|
||||
)
|
||||
|
||||
type lockedError struct {
|
||||
err gperr.Error
|
||||
err error
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (le *lockedError) Get() gperr.Error {
|
||||
func (le *lockedError) Get() error {
|
||||
le.lock.Lock()
|
||||
defer le.lock.Unlock()
|
||||
return le.err
|
||||
}
|
||||
|
||||
func (le *lockedError) Set(err gperr.Error) {
|
||||
func (le *lockedError) Set(err error) {
|
||||
le.lock.Lock()
|
||||
defer le.lock.Unlock()
|
||||
le.err = err
|
||||
@@ -131,7 +131,7 @@ func (r Routes) Contains(alias string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *Route) Validate() gperr.Error {
|
||||
func (r *Route) Validate() error {
|
||||
// wait for alias to be set
|
||||
if r.Alias == "" {
|
||||
return nil
|
||||
@@ -150,13 +150,13 @@ func (r *Route) Validate() gperr.Error {
|
||||
return r.valErr.Get()
|
||||
}
|
||||
|
||||
func (r *Route) validate() gperr.Error {
|
||||
func (r *Route) validate() error {
|
||||
// if strings.HasPrefix(r.Alias, "godoxy") {
|
||||
// log.Debug().Any("route", r).Msg("validating route")
|
||||
// }
|
||||
if r.Agent != "" {
|
||||
if r.Container != nil {
|
||||
return gperr.Errorf("specifying agent is not allowed for docker container routes")
|
||||
return errors.New("specifying agent is not allowed for docker container routes")
|
||||
}
|
||||
var ok bool
|
||||
// by agent address
|
||||
@@ -165,7 +165,7 @@ func (r *Route) validate() gperr.Error {
|
||||
// fallback to get agent by name
|
||||
r.agent, ok = agentpool.GetAgent(r.Agent)
|
||||
if !ok {
|
||||
return gperr.Errorf("agent %s not found", r.Agent)
|
||||
return fmt.Errorf("agent %s not found", r.Agent)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,7 +200,11 @@ func (r *Route) validate() gperr.Error {
|
||||
|
||||
if (r.Proxmox == nil || r.Proxmox.Node == "" || r.Proxmox.VMID == nil) && r.Container == nil {
|
||||
wasNotNil := r.Proxmox != nil
|
||||
proxmoxProviders := config.WorkingState.Load().Value().Providers.Proxmox
|
||||
workingState := config.WorkingState.Load()
|
||||
var proxmoxProviders []*proxmox.Config
|
||||
if workingState != nil { // nil in tests
|
||||
proxmoxProviders = workingState.Value().Providers.Proxmox
|
||||
}
|
||||
if len(proxmoxProviders) > 0 {
|
||||
// it's fine if ip is nil
|
||||
hostname := r.Host
|
||||
@@ -208,40 +212,34 @@ func (r *Route) validate() gperr.Error {
|
||||
for _, p := range proxmoxProviders {
|
||||
// First check if hostname, IP, or alias matches a node (node-level route)
|
||||
if nodeName := p.Client().ReverseLookupNode(hostname, ip, r.Alias); nodeName != "" {
|
||||
zero := 0
|
||||
zero := uint64(0)
|
||||
if r.Proxmox == nil {
|
||||
r.Proxmox = &proxmox.NodeConfig{}
|
||||
}
|
||||
r.Proxmox.Node = nodeName
|
||||
r.Proxmox.VMID = &zero
|
||||
r.Proxmox.VMName = ""
|
||||
log.Info().
|
||||
Str("node", nodeName).
|
||||
Msgf("found proxmox node for route %q", r.Alias)
|
||||
log.Info().EmbedObject(r).Msg("found proxmox node")
|
||||
break
|
||||
}
|
||||
|
||||
// Then check if hostname, IP, or alias matches a VM resource
|
||||
resource, _ := p.Client().ReverseLookupResource(ip, hostname, r.Alias)
|
||||
if resource != nil {
|
||||
vmid := int(resource.VMID)
|
||||
vmid := resource.VMID
|
||||
if r.Proxmox == nil {
|
||||
r.Proxmox = &proxmox.NodeConfig{}
|
||||
}
|
||||
r.Proxmox.Node = resource.Node
|
||||
r.Proxmox.VMID = &vmid
|
||||
r.Proxmox.VMName = resource.Name
|
||||
log.Info().
|
||||
Str("node", resource.Node).
|
||||
Int("vmid", int(resource.VMID)).
|
||||
Str("vmname", resource.Name).
|
||||
Msgf("found proxmox resource for route %q", r.Alias)
|
||||
log.Info().EmbedObject(r).Msg("found proxmox resource")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if wasNotNil && (r.Proxmox.Node == "" || r.Proxmox.VMID == nil) {
|
||||
log.Warn().Msgf("no proxmox node / resource found for route %q", r.Alias)
|
||||
log.Warn().EmbedObject(r).Msg("no proxmox node / resource found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,7 +258,7 @@ func (r *Route) validate() gperr.Error {
|
||||
switch r.Port.Proxy {
|
||||
case common.ProxyHTTPPort, common.ProxyHTTPSPort, common.APIHTTPPort:
|
||||
if r.Scheme.IsReverseProxy() || r.Scheme == route.SchemeTCP {
|
||||
return gperr.Errorf("localhost:%d is reserved for godoxy", r.Port.Proxy)
|
||||
return fmt.Errorf("localhost:%d is reserved for godoxy", r.Port.Proxy)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -271,27 +269,19 @@ func (r *Route) validate() gperr.Error {
|
||||
errs.Add(err)
|
||||
}
|
||||
|
||||
var impl types.Route
|
||||
var err gperr.Error
|
||||
|
||||
switch r.Scheme {
|
||||
case route.SchemeFileServer:
|
||||
r.Host = ""
|
||||
r.Port.Proxy = 0
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
||||
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
|
||||
if r.Port.Listening != 0 {
|
||||
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
|
||||
}
|
||||
if r.ShouldExclude() {
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s", r.Scheme, net.JoinHostPort(r.Host, strconv.Itoa(r.Port.Proxy))))
|
||||
case route.SchemeTCP, route.SchemeUDP:
|
||||
if r.ShouldExclude() {
|
||||
// should exclude, we don't care the scheme here.
|
||||
} else {
|
||||
switch r.Scheme {
|
||||
case route.SchemeFileServer:
|
||||
r.Host = ""
|
||||
r.Port.Proxy = 0
|
||||
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
||||
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
|
||||
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, "https://"+net.JoinHostPort(r.Bind, strconv.Itoa(r.Port.Listening)))
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s", r.Scheme, net.JoinHostPort(r.Host, strconv.Itoa(r.Port.Proxy))))
|
||||
} else {
|
||||
if r.Bind == "" {
|
||||
r.Bind = "0.0.0.0"
|
||||
}
|
||||
case route.SchemeTCP, route.SchemeUDP:
|
||||
bindIP := net.ParseIP(r.Bind)
|
||||
remoteIP := net.ParseIP(r.Host)
|
||||
toNetwork := func(ip net.IP, scheme route.Scheme) string {
|
||||
@@ -325,6 +315,8 @@ func (r *Route) validate() gperr.Error {
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
var impl types.Route
|
||||
var err error
|
||||
switch r.Scheme {
|
||||
case route.SchemeFileServer:
|
||||
impl, err = NewFileServer(r)
|
||||
@@ -360,8 +352,8 @@ func (r *Route) validateRules() error {
|
||||
return errors.New("rule preset `webui.yml` not found")
|
||||
}
|
||||
r.Rules = rules
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.RuleFile != "" && len(r.Rules) > 0 {
|
||||
@@ -397,7 +389,7 @@ func (r *Route) validateRules() error {
|
||||
}
|
||||
|
||||
func (r *Route) validateProxmox() {
|
||||
l := log.With().Str("route", r.Alias).Logger()
|
||||
l := log.With().EmbedObject(r).Logger()
|
||||
|
||||
nodeName := r.Proxmox.Node
|
||||
vmid := r.Proxmox.VMID
|
||||
@@ -426,7 +418,7 @@ func (r *Route) validateProxmox() {
|
||||
} else {
|
||||
res, err := node.Client().GetResource("lxc", *vmid)
|
||||
if err != nil { // ErrResourceNotFound
|
||||
l.Err(err).Msgf("failed to get resource %d", *vmid)
|
||||
l.Error().Err(err).Msgf("failed to get resource %d", *vmid)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -445,24 +437,22 @@ func (r *Route) validateProxmox() {
|
||||
return
|
||||
}
|
||||
|
||||
l = l.With().Str("container", containerName).Logger()
|
||||
|
||||
l.Info().Msgf("checking if container is running")
|
||||
l.Info().Str("container", containerName).Msg("checking if container is running")
|
||||
running, err := node.LXCIsRunning(ctx, *vmid)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("failed to check container state")
|
||||
l.Error().Err(err).Msgf("failed to check container state")
|
||||
return
|
||||
}
|
||||
|
||||
if !running {
|
||||
l.Info().Msgf("starting container")
|
||||
l.Info().Msg("starting container")
|
||||
if err := node.LXCAction(ctx, *vmid, proxmox.LXCStart); err != nil {
|
||||
l.Err(err).Msgf("failed to start container")
|
||||
l.Error().Err(err).Msg("failed to start container")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.Info().Msgf("finding reachable ip addresses")
|
||||
l.Info().Msg("finding reachable ip addresses")
|
||||
errs := gperr.NewBuilder("failed to find reachable ip addresses")
|
||||
for _, ip := range ips {
|
||||
if err := netutils.PingTCP(ctx, ip, r.Port.Proxy); err != nil {
|
||||
@@ -488,23 +478,23 @@ func (r *Route) Task() *task.Task {
|
||||
return r.task
|
||||
}
|
||||
|
||||
func (r *Route) Start(parent task.Parent) gperr.Error {
|
||||
func (r *Route) Start(parent task.Parent) error {
|
||||
r.onceStart.Do(func() {
|
||||
r.startErr.Set(r.start(parent))
|
||||
})
|
||||
return r.startErr.Get()
|
||||
}
|
||||
|
||||
func (r *Route) start(parent task.Parent) gperr.Error {
|
||||
func (r *Route) start(parent task.Parent) error {
|
||||
if r.impl == nil { // should not happen
|
||||
return gperr.New("route not initialized")
|
||||
return errors.New("route not initialized")
|
||||
}
|
||||
defer close(r.started)
|
||||
|
||||
// skip checking for excluded routes
|
||||
excluded := r.ShouldExclude()
|
||||
if !excluded {
|
||||
if err := checkExists(r); err != nil {
|
||||
if err := checkExists(parent.Context(), r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -518,15 +508,23 @@ func (r *Route) start(parent task.Parent) gperr.Error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
r.task = parent.Subtask("excluded."+r.Name(), true)
|
||||
routes.Excluded.Add(r.impl)
|
||||
ep := entrypoint.FromCtx(parent.Context())
|
||||
if ep == nil {
|
||||
return errors.New("entrypoint not initialized")
|
||||
}
|
||||
|
||||
r.task = parent.Subtask("excluded."+r.Name(), false)
|
||||
r.task.SetValue(monitor.DisplayNameKey{}, r.DisplayName())
|
||||
ep.ExcludedRoutes().Add(r.impl)
|
||||
r.task.OnCancel("remove_route_from_excluded", func() {
|
||||
routes.Excluded.Del(r.impl)
|
||||
ep.ExcludedRoutes().Del(r.impl)
|
||||
})
|
||||
if r.UseHealthCheck() {
|
||||
r.HealthMon = monitor.NewMonitor(r.impl)
|
||||
err := r.HealthMon.Start(r.task)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -564,6 +562,10 @@ func (r *Route) ProviderName() string {
|
||||
return r.Provider
|
||||
}
|
||||
|
||||
func (r *Route) ListenURL() *nettypes.URL {
|
||||
return r.LisURL
|
||||
}
|
||||
|
||||
func (r *Route) TargetURL() *nettypes.URL {
|
||||
return r.ProxyURL
|
||||
}
|
||||
@@ -587,10 +589,9 @@ func (r *Route) References() []string {
|
||||
return []string{r.Proxmox.VMName, aliasRef, r.Proxmox.Services[0]}
|
||||
}
|
||||
return []string{r.Proxmox.Services[0], aliasRef}
|
||||
} else {
|
||||
if r.Proxmox.VMName != aliasRef {
|
||||
return []string{r.Proxmox.VMName, aliasRef}
|
||||
}
|
||||
}
|
||||
if r.Proxmox.VMName != aliasRef {
|
||||
return []string{r.Proxmox.VMName, aliasRef}
|
||||
}
|
||||
}
|
||||
return []string{aliasRef}
|
||||
@@ -678,6 +679,44 @@ func (r *Route) DisplayName() string {
|
||||
return r.Homepage.Name
|
||||
}
|
||||
|
||||
func (r *Route) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Str("alias", r.Alias)
|
||||
switch r := r.impl.(type) {
|
||||
case *ReverseProxyRoute:
|
||||
e.Str("type", "reverse_proxy").
|
||||
Str("scheme", r.Scheme.String()).
|
||||
Str("bind", r.LisURL.Host).
|
||||
Str("target", r.ProxyURL.URL.String())
|
||||
case *FileServer:
|
||||
e.Str("type", "file_server").
|
||||
Str("root", r.Root)
|
||||
case *StreamRoute:
|
||||
e.Str("type", "stream").
|
||||
Str("scheme", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme)
|
||||
if r.stream != nil {
|
||||
// listening port could be zero (random),
|
||||
// use LocalAddr() to get the actual listening host+port.
|
||||
e.Str("bind", r.stream.LocalAddr().String())
|
||||
} else {
|
||||
// not yet started
|
||||
e.Str("bind", r.LisURL.Host)
|
||||
}
|
||||
e.Str("target", r.ProxyURL.URL.String())
|
||||
}
|
||||
if r.Proxmox != nil {
|
||||
e.Str("proxmox", r.Proxmox.Node)
|
||||
if r.Proxmox.VMID != nil {
|
||||
e.Uint64("vmid", *r.Proxmox.VMID)
|
||||
}
|
||||
if r.Proxmox.VMName != "" {
|
||||
e.Str("vmname", r.Proxmox.VMName)
|
||||
}
|
||||
}
|
||||
if r.Container != nil {
|
||||
e.Str("container", r.Container.ContainerName)
|
||||
}
|
||||
}
|
||||
|
||||
// PreferOver implements pool.Preferable to resolve duplicate route keys deterministically.
|
||||
// Preference policy:
|
||||
// - Prefer routes with rules over routes without rules.
|
||||
@@ -689,7 +728,7 @@ func (r *Route) PreferOver(other any) bool {
|
||||
switch v := other.(type) {
|
||||
case *Route:
|
||||
or = v
|
||||
case *ReveseProxyRoute:
|
||||
case *ReverseProxyRoute:
|
||||
or = v.Route
|
||||
case *FileServer:
|
||||
or = v.Route
|
||||
@@ -932,6 +971,13 @@ func (r *Route) Finalize() {
|
||||
}
|
||||
}
|
||||
|
||||
switch r.Scheme {
|
||||
case route.SchemeTCP, route.SchemeUDP:
|
||||
if r.Bind == "" {
|
||||
r.Bind = "0.0.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
r.Port.Listening, r.Port.Proxy = lp, pp
|
||||
|
||||
workingState := config.WorkingState.Load()
|
||||
@@ -942,7 +988,8 @@ func (r *Route) Finalize() {
|
||||
panic("bug: working state is nil")
|
||||
}
|
||||
|
||||
r.HealthCheck.ApplyDefaults(config.WorkingState.Load().Value().Defaults.HealthCheck)
|
||||
// TODO: default value from context
|
||||
r.HealthCheck.ApplyDefaults(workingState.Value().Defaults.HealthCheck)
|
||||
}
|
||||
|
||||
func (r *Route) FinalizeHomepageConfig() {
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
route "github.com/yusing/godoxy/internal/route/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
func TestRouteValidate(t *testing.T) {
|
||||
@@ -19,20 +19,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
Port: route.Port{Proxy: common.ProxyHTTPPort},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.HasError(t, err, "Validate should return error for localhost with reserved port")
|
||||
expect.ErrorContains(t, err, "reserved for godoxy")
|
||||
})
|
||||
|
||||
t.Run("ListeningPortWithHTTP", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: route.Port{Proxy: 80, Listening: 1234},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.HasError(t, err, "Validate should return error for HTTP scheme with listening port")
|
||||
expect.ErrorContains(t, err, "unexpected listening port")
|
||||
require.Error(t, err, "Validate should return error for localhost with reserved port")
|
||||
require.ErrorContains(t, err, "reserved for godoxy")
|
||||
})
|
||||
|
||||
t.Run("DisabledHealthCheckWithLoadBalancer", func(t *testing.T) {
|
||||
@@ -49,8 +37,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
}, // Minimal LoadBalance config with non-empty Link will be checked by UseLoadBalance
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.HasError(t, err, "Validate should return error for disabled healthcheck with loadbalancer")
|
||||
expect.ErrorContains(t, err, "cannot disable healthcheck")
|
||||
require.Error(t, err, "Validate should return error for disabled healthcheck with loadbalancer")
|
||||
require.ErrorContains(t, err, "cannot disable healthcheck")
|
||||
})
|
||||
|
||||
t.Run("FileServerScheme", func(t *testing.T) {
|
||||
@@ -62,8 +50,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
Root: "/tmp", // Root is required for file server
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err, "Validate should not return error for valid file server route")
|
||||
expect.NotNil(t, r.impl, "Impl should be initialized")
|
||||
require.NoError(t, err, "Validate should not return error for valid file server route")
|
||||
require.NotNil(t, r.impl, "Impl should be initialized")
|
||||
})
|
||||
|
||||
t.Run("HTTPScheme", func(t *testing.T) {
|
||||
@@ -74,8 +62,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
Port: route.Port{Proxy: 80},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err, "Validate should not return error for valid HTTP route")
|
||||
expect.NotNil(t, r.impl, "Impl should be initialized")
|
||||
require.NoError(t, err, "Validate should not return error for valid HTTP route")
|
||||
require.NotNil(t, r.impl, "Impl should be initialized")
|
||||
})
|
||||
|
||||
t.Run("TCPScheme", func(t *testing.T) {
|
||||
@@ -86,8 +74,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
Port: route.Port{Proxy: 80, Listening: 8080},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err, "Validate should not return error for valid TCP route")
|
||||
expect.NotNil(t, r.impl, "Impl should be initialized")
|
||||
require.NoError(t, err, "Validate should not return error for valid TCP route")
|
||||
require.NotNil(t, r.impl, "Impl should be initialized")
|
||||
})
|
||||
|
||||
t.Run("DockerContainer", func(t *testing.T) {
|
||||
@@ -106,8 +94,8 @@ func TestRouteValidate(t *testing.T) {
|
||||
},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err, "Validate should not return error for valid docker container route")
|
||||
expect.NotNil(t, r.ProxyURL, "ProxyURL should be set")
|
||||
require.NoError(t, err, "Validate should not return error for valid docker container route")
|
||||
require.NotNil(t, r.ProxyURL, "ProxyURL should be set")
|
||||
})
|
||||
|
||||
t.Run("InvalidScheme", func(t *testing.T) {
|
||||
@@ -117,7 +105,7 @@ func TestRouteValidate(t *testing.T) {
|
||||
Host: "example.com",
|
||||
Port: route.Port{Proxy: 80},
|
||||
}
|
||||
expect.Panics(t, func() {
|
||||
require.Panics(t, func() {
|
||||
_ = r.Validate()
|
||||
}, "Validate should panic for invalid scheme")
|
||||
})
|
||||
@@ -130,9 +118,9 @@ func TestRouteValidate(t *testing.T) {
|
||||
Port: route.Port{Proxy: 80},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err)
|
||||
expect.NotNil(t, r.ProxyURL)
|
||||
expect.NotNil(t, r.HealthCheck)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, r.ProxyURL)
|
||||
require.NotNil(t, r.HealthCheck)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -144,7 +132,7 @@ func TestPreferredPort(t *testing.T) {
|
||||
}
|
||||
|
||||
port := preferredPort(ports)
|
||||
expect.Equal(t, port, 3000)
|
||||
require.Equal(t, 3000, port)
|
||||
}
|
||||
|
||||
func TestDockerRouteDisallowAgent(t *testing.T) {
|
||||
@@ -164,8 +152,8 @@ func TestDockerRouteDisallowAgent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.HasError(t, err, "Validate should return error for docker route with agent")
|
||||
expect.ErrorContains(t, err, "specifying agent is not allowed for docker container routes")
|
||||
require.Error(t, err, "Validate should return error for docker route with agent")
|
||||
require.ErrorContains(t, err, "specifying agent is not allowed for docker container routes")
|
||||
}
|
||||
|
||||
func TestRouteAgent(t *testing.T) {
|
||||
@@ -177,8 +165,8 @@ func TestRouteAgent(t *testing.T) {
|
||||
Agent: "test-agent",
|
||||
}
|
||||
err := r.Validate()
|
||||
expect.NoError(t, err, "Validate should not return error for valid route with agent")
|
||||
expect.NotNil(t, r.GetAgent(), "GetAgent should return agent")
|
||||
require.NoError(t, err, "Validate should not return error for valid route with agent")
|
||||
require.NotNil(t, r.GetAgent(), "GetAgent should return agent")
|
||||
}
|
||||
|
||||
func TestRouteApplyingHealthCheckDefaults(t *testing.T) {
|
||||
@@ -188,6 +176,106 @@ func TestRouteApplyingHealthCheckDefaults(t *testing.T) {
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
|
||||
expect.Equal(t, hc.Interval, 15*time.Second)
|
||||
expect.Equal(t, hc.Timeout, 10*time.Second)
|
||||
require.Equal(t, 15*time.Second, hc.Interval)
|
||||
require.Equal(t, 10*time.Second, hc.Timeout)
|
||||
}
|
||||
|
||||
func TestRouteBindField(t *testing.T) {
|
||||
t.Run("TCPSchemeWithCustomBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-tcp",
|
||||
Scheme: route.SchemeTCP,
|
||||
Host: "192.168.1.100",
|
||||
Port: route.Port{Proxy: 80, Listening: 8080},
|
||||
Bind: "192.168.1.1",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for TCP route with custom bind")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "tcp4://192.168.1.1:8080", r.LisURL.String(), "LisURL should contain custom bind address")
|
||||
})
|
||||
|
||||
t.Run("UDPSchemeWithCustomBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-udp",
|
||||
Scheme: route.SchemeUDP,
|
||||
Host: "10.0.0.1",
|
||||
Port: route.Port{Proxy: 53, Listening: 53},
|
||||
Bind: "10.0.0.254",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for UDP route with custom bind")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "udp4://10.0.0.254:53", r.LisURL.String(), "LisURL should contain custom bind address")
|
||||
})
|
||||
|
||||
t.Run("HTTPSchemeWithoutBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-http",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: route.Port{Proxy: 80},
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for HTTP route without bind")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "https://:0", r.LisURL.String(), "LisURL should contain bind address")
|
||||
})
|
||||
|
||||
t.Run("HTTPSchemeWithBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-http",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: route.Port{Proxy: 80},
|
||||
Bind: "0.0.0.0",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for HTTP route with bind")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "https://0.0.0.0:0", r.LisURL.String(), "LisURL should contain bind address")
|
||||
})
|
||||
|
||||
t.Run("HTTPSchemeWithBindAndPort", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-http",
|
||||
Scheme: route.SchemeHTTP,
|
||||
Host: "example.com",
|
||||
Port: route.Port{Listening: 8080, Proxy: 80},
|
||||
Bind: "0.0.0.0",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for HTTP route with bind and port")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "https://0.0.0.0:8080", r.LisURL.String(), "LisURL should contain bind address and listening port")
|
||||
})
|
||||
|
||||
t.Run("TCPSchemeDefaultsToZeroBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-default-bind",
|
||||
Scheme: route.SchemeTCP,
|
||||
Host: "example.com",
|
||||
Port: route.Port{Proxy: 80, Listening: 8080},
|
||||
Bind: "",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for TCP route with empty bind")
|
||||
require.Equal(t, "0.0.0.0", r.Bind, "Bind should default to 0.0.0.0 for TCP scheme")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "tcp4://0.0.0.0:8080", r.LisURL.String(), "LisURL should use default bind address")
|
||||
})
|
||||
|
||||
t.Run("FileServerSchemeWithBind", func(t *testing.T) {
|
||||
r := &Route{
|
||||
Alias: "test-fileserver",
|
||||
Scheme: route.SchemeFileServer,
|
||||
Port: route.Port{Listening: 9000},
|
||||
Root: "/tmp",
|
||||
Bind: "127.0.0.1",
|
||||
}
|
||||
err := r.Validate()
|
||||
require.NoError(t, err, "Validate should not return error for fileserver route with bind")
|
||||
require.NotNil(t, r.LisURL, "LisURL should be set")
|
||||
require.Equal(t, "https://127.0.0.1:9000", r.LisURL.String(), "LisURL should contain bind address")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
# Route Registry
|
||||
|
||||
Provides centralized route registry with O(1) lookups and route context management for HTTP handlers.
|
||||
|
||||
## Overview
|
||||
|
||||
The `internal/route/routes` package maintains the global route registry for GoDoxy. It provides thread-safe route lookups by alias, route iteration, and utilities for propagating route context through HTTP request handlers.
|
||||
|
||||
### Primary Consumers
|
||||
|
||||
- **HTTP handlers**: Lookup routes and extract request context
|
||||
- **Route providers**: Register and unregister routes
|
||||
- **Health system**: Query route health status
|
||||
- **WebUI**: Display route information
|
||||
|
||||
### Non-goals
|
||||
|
||||
- Does not create or modify routes
|
||||
- Does not handle route validation
|
||||
- Does not implement routing logic (matching)
|
||||
|
||||
### Stability
|
||||
|
||||
Internal package with stable public API.
|
||||
|
||||
## Public API
|
||||
|
||||
### Route Pools
|
||||
|
||||
```go
|
||||
var (
|
||||
HTTP = pool.New[types.HTTPRoute]("http_routes")
|
||||
Stream = pool.New[types.StreamRoute]("stream_routes")
|
||||
Excluded = pool.New[types.Route]("excluded_routes")
|
||||
)
|
||||
```
|
||||
|
||||
Pool methods:
|
||||
|
||||
- `Get(alias string) (T, bool)` - O(1) lookup
|
||||
- `Add(r T)` - Register route
|
||||
- `Del(r T)` - Unregister route
|
||||
- `Size() int` - Route count
|
||||
- `Clear()` - Remove all routes
|
||||
- `Iter` - Channel-based iteration
|
||||
|
||||
### Exported Functions
|
||||
|
||||
```go
|
||||
// Iterate over active routes (HTTP + Stream)
|
||||
func IterActive(yield func(r types.Route) bool)
|
||||
|
||||
// Iterate over all routes (HTTP + Stream + Excluded)
|
||||
func IterAll(yield func(r types.Route) bool)
|
||||
|
||||
// Get route count
|
||||
func NumActiveRoutes() int
|
||||
func NumAllRoutes() int
|
||||
|
||||
// Clear all routes
|
||||
func Clear()
|
||||
|
||||
// Lookup functions
|
||||
func Get(alias string) (types.Route, bool)
|
||||
func GetHTTPRouteOrExact(alias, host string) (types.HTTPRoute, bool)
|
||||
```
|
||||
|
||||
### Route Context
|
||||
|
||||
```go
|
||||
type RouteContext struct {
|
||||
context.Context
|
||||
Route types.HTTPRoute
|
||||
}
|
||||
|
||||
// Attach route to request context (uses unsafe pointer for performance)
|
||||
func WithRouteContext(r *http.Request, route types.HTTPRoute) *http.Request
|
||||
|
||||
// Extract route from request context
|
||||
func TryGetRoute(r *http.Request) types.HTTPRoute
|
||||
```
|
||||
|
||||
### Upstream Information
|
||||
|
||||
```go
|
||||
func TryGetUpstreamName(r *http.Request) string
|
||||
func TryGetUpstreamScheme(r *http.Request) string
|
||||
func TryGetUpstreamHost(r *http.Request) string
|
||||
func TryGetUpstreamPort(r *http.Request) string
|
||||
func TryGetUpstreamHostPort(r *http.Request) string
|
||||
func TryGetUpstreamAddr(r *http.Request) string
|
||||
func TryGetUpstreamURL(r *http.Request) string
|
||||
```
|
||||
|
||||
### Health Information
|
||||
|
||||
```go
|
||||
type HealthInfo struct {
|
||||
HealthInfoWithoutDetail
|
||||
Detail string
|
||||
}
|
||||
|
||||
type HealthInfoWithoutDetail struct {
|
||||
Status types.HealthStatus
|
||||
Uptime time.Duration
|
||||
Latency time.Duration
|
||||
}
|
||||
|
||||
func GetHealthInfo() map[string]HealthInfo
|
||||
func GetHealthInfoWithoutDetail() map[string]HealthInfoWithoutDetail
|
||||
func GetHealthInfoSimple() map[string]types.HealthStatus
|
||||
```
|
||||
|
||||
### Provider Grouping
|
||||
|
||||
```go
|
||||
func ByProvider() map[string][]types.Route
|
||||
```
|
||||
|
||||
## Proxmox Integration
|
||||
|
||||
Routes can be automatically linked to Proxmox nodes or LXC containers through reverse lookup during validation.
|
||||
|
||||
### Node-Level Routes
|
||||
|
||||
Routes can be linked to a Proxmox node directly (VMID = 0) when the route's hostname, IP, or alias matches a node name or IP:
|
||||
|
||||
```go
|
||||
// Route linked to Proxmox node (no specific VM)
|
||||
route.Proxmox = &proxmox.NodeConfig{
|
||||
Node: "pve-node-01",
|
||||
VMID: 0, // node-level, no container
|
||||
VMName: "",
|
||||
}
|
||||
```
|
||||
|
||||
### Container-Level Routes
|
||||
|
||||
Routes are linked to LXC containers when they match a VM resource by hostname, IP, or alias:
|
||||
|
||||
```go
|
||||
// Route linked to LXC container
|
||||
route.Proxmox = &proxmox.NodeConfig{
|
||||
Node: "pve-node-01",
|
||||
VMID: 100,
|
||||
VMName: "my-container",
|
||||
}
|
||||
```
|
||||
|
||||
### Lookup Priority
|
||||
|
||||
1. **Node match** - If hostname, IP, or alias matches a Proxmox node
|
||||
2. **VM match** - If hostname, IP, or alias matches a VM resource
|
||||
|
||||
Node-level routes skip container control logic (start/check IPs) and can be used to proxy node services directly.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class HTTP
|
||||
class Stream
|
||||
class Excluded
|
||||
class RouteContext
|
||||
|
||||
HTTP : +Get(alias) T
|
||||
HTTP : +Add(r)
|
||||
HTTP : +Del(r)
|
||||
HTTP : +Size() int
|
||||
HTTP : +Iter chan
|
||||
|
||||
Stream : +Get(alias) T
|
||||
Stream : +Add(r)
|
||||
Stream : +Del(r)
|
||||
|
||||
Excluded : +Get(alias) T
|
||||
Excluded : +Add(r)
|
||||
Excluded : +Del(r)
|
||||
```
|
||||
|
||||
### Route Lookup Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Lookup Request] --> B{HTTP Pool}
|
||||
B -->|Found| C[Return Route]
|
||||
B -->|Not Found| D{Stream Pool}
|
||||
D -->|Found| C
|
||||
D -->|Not Found| E[Return nil]
|
||||
```
|
||||
|
||||
### Context Propagation
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant H as HTTP Handler
|
||||
participant R as Registry
|
||||
participant C as RouteContext
|
||||
|
||||
H->>R: WithRouteContext(req, route)
|
||||
R->>C: Attach route via unsafe pointer
|
||||
C-->>H: Modified request
|
||||
|
||||
H->>R: TryGetRoute(req)
|
||||
R->>C: Extract route from context
|
||||
C-->>R: Route
|
||||
R-->>H: Route
|
||||
```
|
||||
|
||||
## Dependency and Integration Map
|
||||
|
||||
| Dependency | Purpose |
|
||||
| -------------------------------- | ---------------------------------- |
|
||||
| `internal/types` | Route and health type definitions |
|
||||
| `internal/proxmox` | Proxmox node/container integration |
|
||||
| `github.com/yusing/goutils/pool` | Thread-safe pool implementation |
|
||||
|
||||
## Observability
|
||||
|
||||
### Logs
|
||||
|
||||
Registry operations logged at DEBUG level:
|
||||
|
||||
- Route add/remove
|
||||
- Pool iteration
|
||||
- Context operations
|
||||
|
||||
### Performance
|
||||
|
||||
- `WithRouteContext` uses `unsafe.Pointer` to avoid request cloning
|
||||
- Route lookups are O(1) using internal maps
|
||||
- Iteration uses channels for memory efficiency
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Route context propagation is internal to the process
|
||||
- No sensitive data exposed in context keys
|
||||
- Routes are validated before registration
|
||||
|
||||
## Failure Modes and Recovery
|
||||
|
||||
| Failure | Behavior | Recovery |
|
||||
| ---------------------------------------- | ------------------------------ | -------------------- |
|
||||
| Route not found | Returns (nil, false) | Verify route alias |
|
||||
| Context extraction on non-route request | Returns nil | Check request origin |
|
||||
| Concurrent modification during iteration | Handled by pool implementation | N/A |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Route Lookup
|
||||
|
||||
```go
|
||||
route, ok := routes.Get("myapp")
|
||||
if !ok {
|
||||
return fmt.Errorf("route not found")
|
||||
}
|
||||
```
|
||||
|
||||
### Iterating Over All Routes
|
||||
|
||||
```go
|
||||
for r := range routes.IterActive {
|
||||
log.Printf("Route: %s", r.Name())
|
||||
}
|
||||
```
|
||||
|
||||
### Getting Health Status
|
||||
|
||||
```go
|
||||
healthMap := routes.GetHealthInfo()
|
||||
for name, health := range healthMap {
|
||||
log.Printf("Route %s: %s (uptime: %v)", name, health.Status, health.Uptime)
|
||||
}
|
||||
```
|
||||
|
||||
### Using Route Context in Handler
|
||||
|
||||
```go
|
||||
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
route := routes.TryGetRoute(r)
|
||||
if route == nil {
|
||||
http.Error(w, "Route not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamHost := routes.TryGetUpstreamHost(r)
|
||||
log.Printf("Proxying to: %s", upstreamHost)
|
||||
}
|
||||
```
|
||||
|
||||
### Grouping Routes by Provider
|
||||
|
||||
```go
|
||||
byProvider := routes.ByProvider()
|
||||
for providerName, routeList := range byProvider {
|
||||
log.Printf("Provider %s: %d routes", providerName, len(routeList))
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Notes
|
||||
|
||||
- Unit tests for pool thread safety
|
||||
- Context propagation tests
|
||||
- Health info aggregation tests
|
||||
- Provider grouping tests
|
||||
@@ -1,103 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
)
|
||||
|
||||
type HealthInfo struct {
|
||||
HealthInfoWithoutDetail
|
||||
Detail string `json:"detail"`
|
||||
} // @name HealthInfo
|
||||
|
||||
type HealthInfoWithoutDetail struct {
|
||||
Status types.HealthStatus `json:"status" swaggertype:"string" enums:"healthy,unhealthy,napping,starting,error,unknown"`
|
||||
Uptime time.Duration `json:"uptime" swaggertype:"number"` // uptime in milliseconds
|
||||
Latency time.Duration `json:"latency" swaggertype:"number"` // latency in microseconds
|
||||
} // @name HealthInfoWithoutDetail
|
||||
|
||||
type HealthMap = map[string]types.HealthStatusString // @name HealthMap
|
||||
|
||||
// GetHealthInfo returns a map of route name to health info.
|
||||
//
|
||||
// The health info is for all routes, including excluded routes.
|
||||
func GetHealthInfo() map[string]HealthInfo {
|
||||
healthMap := make(map[string]HealthInfo, NumAllRoutes())
|
||||
for r := range IterAll {
|
||||
healthMap[r.Name()] = getHealthInfo(r)
|
||||
}
|
||||
return healthMap
|
||||
}
|
||||
|
||||
// GetHealthInfoWithoutDetail returns a map of route name to health info without detail.
|
||||
//
|
||||
// The health info is for all routes, including excluded routes.
|
||||
func GetHealthInfoWithoutDetail() map[string]HealthInfoWithoutDetail {
|
||||
healthMap := make(map[string]HealthInfoWithoutDetail, NumAllRoutes())
|
||||
for r := range IterAll {
|
||||
healthMap[r.Name()] = getHealthInfoWithoutDetail(r)
|
||||
}
|
||||
return healthMap
|
||||
}
|
||||
|
||||
func GetHealthInfoSimple() map[string]types.HealthStatus {
|
||||
healthMap := make(map[string]types.HealthStatus, NumAllRoutes())
|
||||
for r := range IterAll {
|
||||
healthMap[r.Name()] = getHealthInfoSimple(r)
|
||||
}
|
||||
return healthMap
|
||||
}
|
||||
|
||||
func getHealthInfo(r types.Route) HealthInfo {
|
||||
mon := r.HealthMonitor()
|
||||
if mon == nil {
|
||||
return HealthInfo{
|
||||
HealthInfoWithoutDetail: HealthInfoWithoutDetail{
|
||||
Status: types.StatusUnknown,
|
||||
},
|
||||
Detail: "n/a",
|
||||
}
|
||||
}
|
||||
return HealthInfo{
|
||||
HealthInfoWithoutDetail: HealthInfoWithoutDetail{
|
||||
Status: mon.Status(),
|
||||
Uptime: mon.Uptime(),
|
||||
Latency: mon.Latency(),
|
||||
},
|
||||
Detail: mon.Detail(),
|
||||
}
|
||||
}
|
||||
|
||||
func getHealthInfoWithoutDetail(r types.Route) HealthInfoWithoutDetail {
|
||||
mon := r.HealthMonitor()
|
||||
if mon == nil {
|
||||
return HealthInfoWithoutDetail{
|
||||
Status: types.StatusUnknown,
|
||||
}
|
||||
}
|
||||
return HealthInfoWithoutDetail{
|
||||
Status: mon.Status(),
|
||||
Uptime: mon.Uptime(),
|
||||
Latency: mon.Latency(),
|
||||
}
|
||||
}
|
||||
|
||||
func getHealthInfoSimple(r types.Route) types.HealthStatus {
|
||||
mon := r.HealthMonitor()
|
||||
if mon == nil {
|
||||
return types.StatusUnknown
|
||||
}
|
||||
return mon.Status()
|
||||
}
|
||||
|
||||
// ByProvider returns a map of provider name to routes.
|
||||
//
|
||||
// The routes are all routes, including excluded routes.
|
||||
func ByProvider() map[string][]types.Route {
|
||||
rts := make(map[string][]types.Route)
|
||||
for r := range IterAll {
|
||||
rts[r.ProviderName()] = append(rts[r.ProviderName()], r)
|
||||
}
|
||||
return rts
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
"github.com/yusing/goutils/pool"
|
||||
)
|
||||
|
||||
var (
|
||||
HTTP = pool.New[types.HTTPRoute]("http_routes")
|
||||
Stream = pool.New[types.StreamRoute]("stream_routes")
|
||||
|
||||
Excluded = pool.New[types.Route]("excluded_routes")
|
||||
)
|
||||
|
||||
func IterActive(yield func(r types.Route) bool) {
|
||||
for _, r := range HTTP.Iter {
|
||||
if !yield(r) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, r := range Stream.Iter {
|
||||
if !yield(r) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IterAll(yield func(r types.Route) bool) {
|
||||
for _, r := range HTTP.Iter {
|
||||
if !yield(r) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, r := range Stream.Iter {
|
||||
if !yield(r) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, r := range Excluded.Iter {
|
||||
if !yield(r) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NumActiveRoutes() int {
|
||||
return HTTP.Size() + Stream.Size()
|
||||
}
|
||||
|
||||
func NumAllRoutes() int {
|
||||
return HTTP.Size() + Stream.Size() + Excluded.Size()
|
||||
}
|
||||
|
||||
func Clear() {
|
||||
HTTP.Clear()
|
||||
Stream.Clear()
|
||||
Excluded.Clear()
|
||||
}
|
||||
|
||||
func GetHTTPRouteOrExact(alias, host string) (types.HTTPRoute, bool) {
|
||||
r, ok := HTTP.Get(alias)
|
||||
if ok {
|
||||
return r, true
|
||||
}
|
||||
// try find with exact match
|
||||
return HTTP.Get(host)
|
||||
}
|
||||
|
||||
// Get returns the route with the given alias.
|
||||
//
|
||||
// It does not return excluded routes.
|
||||
func Get(alias string) (types.Route, bool) {
|
||||
if r, ok := HTTP.Get(alias); ok {
|
||||
return r, true
|
||||
}
|
||||
if r, ok := Stream.Get(alias); ok {
|
||||
return r, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetIncludeExcluded returns the route with the given alias, including excluded routes.
|
||||
func GetIncludeExcluded(alias string) (types.Route, bool) {
|
||||
if r, ok := HTTP.Get(alias); ok {
|
||||
return r, true
|
||||
}
|
||||
if r, ok := Stream.Get(alias); ok {
|
||||
return r, true
|
||||
}
|
||||
return Excluded.Get(alias)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/logging"
|
||||
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
@@ -71,10 +72,11 @@ var commands = map[string]struct {
|
||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -102,17 +104,17 @@ var commands = map[string]struct {
|
||||
"to": "the path to rewrite to, must start with /",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
path1, err1 := validateURLPath(args[:1])
|
||||
path2, err2 := validateURLPath(args[1:])
|
||||
if err1 != nil {
|
||||
err1 = gperr.Errorf("from: %w", err1)
|
||||
err1 = gperr.PrependSubject(err1, "from")
|
||||
}
|
||||
if err2 != nil {
|
||||
err2 = gperr.Errorf("to: %w", err2)
|
||||
err2 = gperr.PrependSubject(err2, "to")
|
||||
}
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, gperr.Join(err1, err2)
|
||||
@@ -188,7 +190,7 @@ var commands = map[string]struct {
|
||||
"route": "the route to route to",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -197,9 +199,10 @@ var commands = map[string]struct {
|
||||
build: func(args any) CommandHandler {
|
||||
route := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
|
||||
r, ok := routes.HTTP.Get(route)
|
||||
ep := entrypoint.FromCtx(req.Context())
|
||||
r, ok := ep.HTTPRoutes().Get(route)
|
||||
if !ok {
|
||||
excluded, has := routes.Excluded.Get(route)
|
||||
excluded, has := ep.ExcludedRoutes().Get(route)
|
||||
if has {
|
||||
r, ok = excluded.(types.HTTPRoute)
|
||||
}
|
||||
@@ -225,7 +228,7 @@ var commands = map[string]struct {
|
||||
"text": "the error message to return",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -265,7 +268,7 @@ var commands = map[string]struct {
|
||||
"realm": "the authentication realm",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) == 1 {
|
||||
return args[0], nil
|
||||
}
|
||||
@@ -327,12 +330,12 @@ var commands = map[string]struct {
|
||||
helpExample(CommandSet, "header", "User-Agent", "godoxy"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to set, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to set",
|
||||
"value": "the value to set",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldSet, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -347,12 +350,12 @@ var commands = map[string]struct {
|
||||
helpExample(CommandAdd, "header", "X-Foo", "bar"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to add, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to add",
|
||||
"value": "the value to add",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldAdd, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -367,11 +370,11 @@ var commands = map[string]struct {
|
||||
helpExample(CommandRemove, "header", "User-Agent"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")),
|
||||
"target": "the target to remove, can be " + strings.Join(AllFields, ", "),
|
||||
"field": "the field to remove",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
return validateModField(ModFieldRemove, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
@@ -396,7 +399,7 @@ var commands = map[string]struct {
|
||||
"template": "the template to log",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrExpectThreeArgs
|
||||
}
|
||||
@@ -453,7 +456,7 @@ var commands = map[string]struct {
|
||||
"body": "the body of the notification",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 4 {
|
||||
return nil, ErrExpectFourArgs
|
||||
}
|
||||
@@ -509,8 +512,10 @@ var commands = map[string]struct {
|
||||
},
|
||||
}
|
||||
|
||||
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
|
||||
type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
|
||||
type (
|
||||
onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
|
||||
onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
|
||||
)
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
@@ -541,7 +546,7 @@ func (cmd *Command) Parse(v string) error {
|
||||
validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
||||
return err.Subject(directive).With(builder.help.Error())
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
handler := builder.build(validArgs)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
@@ -32,7 +31,7 @@ func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.
|
||||
}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
func parseRules(data string, target *Rules) error {
|
||||
_, err := serialization.ConvertString(data, reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
@@ -54,7 +53,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/users", nil)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -71,7 +70,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
upstream := mockUpstream(http.StatusOK, "success")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -86,7 +85,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -97,7 +96,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
upstream := mockUpstream(404, "not found")
|
||||
upstream := mockUpstream(http.StatusNotFound, "not found")
|
||||
|
||||
infoFile := TestRandomFileName()
|
||||
warnFile := TestRandomFileName()
|
||||
@@ -141,7 +140,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", "custom-value")
|
||||
w.Header().Set("Content-Length", "42")
|
||||
w.WriteHeader(201)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte("created"))
|
||||
})
|
||||
|
||||
@@ -177,13 +176,13 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/error":
|
||||
w.WriteHeader(500)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal server error"))
|
||||
case "/notfound":
|
||||
w.WriteHeader(404)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("not found"))
|
||||
default:
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
})
|
||||
@@ -207,22 +206,22 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test success request
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// Test not found request
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Test server error request
|
||||
req3 := httptest.NewRequest("POST", "/error", nil)
|
||||
req3 := httptest.NewRequest(http.MethodPost, "/error", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
assert.Equal(t, 500, w3.Code)
|
||||
assert.Equal(t, http.StatusInternalServerError, w3.Code)
|
||||
|
||||
// Verify success log
|
||||
successContent := TestFileContent(successFile)
|
||||
@@ -239,7 +238,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
upstream := mockUpstream(200, "response")
|
||||
upstream := mockUpstream(http.StatusOK, "response")
|
||||
|
||||
tempFile := TestRandomFileName()
|
||||
|
||||
@@ -267,7 +266,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Verify all requests were logged
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
@@ -228,7 +227,7 @@ var modFields = map[string]struct {
|
||||
"template": "the body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -273,7 +272,7 @@ var modFields = map[string]struct {
|
||||
"template": "the response body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -301,7 +300,7 @@ var modFields = map[string]struct {
|
||||
"code": "the status code",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -126,8 +126,8 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
values := w.Header()["X-Response-Test"]
|
||||
require.Len(t, values, 2)
|
||||
assert.Equal(t, values[0], "existing-value")
|
||||
assert.Equal(t, values[1], "additional-value")
|
||||
assert.Equal(t, "existing-value", values[0])
|
||||
assert.Equal(t, "additional-value", values[1])
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -143,7 +143,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
if tt.setup != nil {
|
||||
tt.setup(w)
|
||||
@@ -232,7 +232,7 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -330,7 +330,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -396,7 +396,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -494,7 +494,7 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
var cmd Command
|
||||
|
||||
@@ -3,7 +3,7 @@ package rules
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TestErrorFormat(t *testing.T) {
|
||||
@@ -19,5 +19,5 @@ func TestErrorFormat(t *testing.T) {
|
||||
do: set invalid_command
|
||||
- do: set resp_body "{{ .Request.Method {{ .Request.URL.Path }}"
|
||||
`, &rules)
|
||||
gperr.LogError("error", err)
|
||||
log.Err(err).Msg("error")
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ Generate help string as error, e.g.
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) Error() gperr.Error {
|
||||
func (h *Help) Error() error {
|
||||
var lines gperr.MultilineError
|
||||
|
||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
|
||||
@@ -17,16 +17,14 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
. "github.com/yusing/godoxy/internal/route/rules"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
func mockUpstream(body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
@@ -44,7 +42,7 @@ func mockRoute(alias string) *route.FileServer {
|
||||
return &route.FileServer{Route: &route.Route{Alias: alias}}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
func parseRules(data string, target *Rules) error {
|
||||
_, err := serialization.ConvertString(strings.TrimSpace(data), reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +50,7 @@ func parseRules(data string, target *Rules) gperr.Error {
|
||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream response"))
|
||||
})
|
||||
|
||||
@@ -66,18 +64,18 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
upstream := mockUpstream("upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -92,17 +90,17 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/bypass", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/bypass", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
upstream := mockUpstream("should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -117,18 +115,18 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/error", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/error", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
upstream := mockUpstream("should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -140,18 +138,18 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/old-path", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/old-path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("path: " + r.URL.Path))
|
||||
})
|
||||
|
||||
@@ -165,18 +163,18 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/users", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||
})
|
||||
|
||||
@@ -193,18 +191,18 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream: req-123", w.Body.String())
|
||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success", http.Header{
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
||||
"X-Upstream": []string{"upstream-value"},
|
||||
})
|
||||
|
||||
@@ -220,12 +218,12 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "success", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
|
||||
|
||||
@@ -238,10 +236,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/success" {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
} else {
|
||||
w.WriteHeader(404)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("not found"))
|
||||
}
|
||||
})
|
||||
@@ -261,18 +259,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request (should not log)
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// Test error request (should log)
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
@@ -284,7 +282,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||
})
|
||||
|
||||
@@ -305,19 +303,19 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test with Authorization header
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "hello authenticated-user", w1.Body.String())
|
||||
assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test without Authorization header
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "hello anonymous", w2.Body.String())
|
||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||
}
|
||||
@@ -327,13 +325,13 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
// Simulate different responses based on path
|
||||
if r.URL.Path == "/protected" {
|
||||
if r.Header.Get("X-Auth") != "valid" {
|
||||
w.WriteHeader(401)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("unauthorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("X-Response-Time", "100ms")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
})
|
||||
|
||||
@@ -361,32 +359,32 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request
|
||||
req1 := httptest.NewRequest("GET", "/public", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/public", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "success", w1.Body.String())
|
||||
assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id"))
|
||||
assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time"))
|
||||
|
||||
// Test unauthorized protected request
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 401, w2.Code)
|
||||
assert.Equal(t, w2.Body.String(), "Unauthorized\n")
|
||||
assert.Equal(t, http.StatusUnauthorized, w2.Code)
|
||||
assert.Equal(t, "Unauthorized\n", w2.Body.String())
|
||||
|
||||
// Test authorized protected request
|
||||
req3 := httptest.NewRequest("GET", "/protected", nil)
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req3.SetBasicAuth("user", "pass")
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
// This should fail because our simple upstream expects X-Auth: valid header
|
||||
// but the basic auth requirement should add the appropriate header
|
||||
assert.Equal(t, 401, w3.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, w3.Code)
|
||||
|
||||
// Check log files
|
||||
logContent := TestFileContent(logFile)
|
||||
@@ -405,7 +403,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
upstream := mockUpstream("upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -420,20 +418,20 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test default rule
|
||||
req1 := httptest.NewRequest("GET", "/regular", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule + default rule
|
||||
req2 := httptest.NewRequest("GET", "/special", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
@@ -443,7 +441,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
// Echo back a header
|
||||
headerValue := r.Header.Get("X-Test-Header")
|
||||
w.Header().Set("X-Echoed-Header", headerValue)
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("header echoed"))
|
||||
})
|
||||
|
||||
@@ -461,14 +459,14 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Secret", "secret-value")
|
||||
req.Header.Set("X-Test-Header", "original-value")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||
// Ensure the secret header was removed and not passed to upstream
|
||||
@@ -478,7 +476,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("query: " + query.Get("param")))
|
||||
})
|
||||
|
||||
@@ -492,25 +490,23 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/path?param=original", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/path?param=original", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
// The set command should have modified the query parameter
|
||||
assert.Equal(t, "query: added-value", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test files directly in the temp directory
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
@@ -521,7 +517,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
`, tempDir), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
|
||||
// Test serving a file - serve command serves files relative to the root directory
|
||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||
@@ -534,7 +530,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/files/index.html", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/files/index.html", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
@@ -543,18 +539,18 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
assert.NotEqual(t, "should not be called", w1.Body.String())
|
||||
|
||||
// Test file not found
|
||||
req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/files/nonexistent.html", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
// Create a mock upstream server
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream response"))
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
@@ -567,15 +563,15 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
`, upstreamServer.URL), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// The proxy command should forward the request to the upstream server
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||
}
|
||||
@@ -586,7 +582,7 @@ func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("form processed"))
|
||||
})
|
||||
|
||||
@@ -605,28 +601,28 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
|
||||
// Test form condition
|
||||
formData := url.Values{"username": {"john_doe"}}
|
||||
req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode()))
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(formData.Encode()))
|
||||
req1.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "john_doe", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test postform condition
|
||||
postFormData := url.Values{"email": {"john@example.com"}}
|
||||
req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode()))
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(postFormData.Encode()))
|
||||
req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("remote processed"))
|
||||
})
|
||||
|
||||
@@ -644,27 +640,27 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test localhost condition
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.RemoteAddr = "127.0.0.1:12345"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "local", w1.Header().Get("X-Access"))
|
||||
|
||||
// Test private network block
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.RemoteAddr = "192.168.1.100:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 403, w2.Code)
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code)
|
||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("auth processed"))
|
||||
})
|
||||
|
||||
@@ -688,27 +684,27 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin user
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1.SetBasicAuth("admin", "adminpass")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status"))
|
||||
|
||||
// Test guest user
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.SetBasicAuth("guest", "guestpass")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("route processed"))
|
||||
})
|
||||
|
||||
@@ -726,29 +722,29 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test API route
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req1 = routes.WithRouteContext(req1, mockRoute("backend"))
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "backend", w1.Header().Get("X-Route"))
|
||||
|
||||
// Test admin route
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2 = routes.WithRouteContext(req2, mockRoute("frontend"))
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(405)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte("method not allowed"))
|
||||
})
|
||||
|
||||
@@ -763,18 +759,18 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Response-Header", "response header")
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("processed"))
|
||||
})
|
||||
|
||||
@@ -789,11 +785,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
t.Run("with_value", func(t *testing.T) {
|
||||
@@ -807,11 +803,11 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
|
||||
@@ -826,18 +822,18 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "processed", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("complex processed"))
|
||||
})
|
||||
|
||||
@@ -868,26 +864,26 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin API (should match first rule)
|
||||
req1 := httptest.NewRequest("POST", "/api/admin/users", nil)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/admin/users", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w1.Header()["X-API-Version"][0])
|
||||
|
||||
// Test user API (should match second rule)
|
||||
req2 := httptest.NewRequest("GET", "/api/users/profile", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/users/profile", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "user", w2.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w2.Header()["X-API-Version"][0])
|
||||
|
||||
// Test public API (should match third rule)
|
||||
req3 := httptest.NewRequest("GET", "/api/public/info", nil)
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/api/public/info", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
@@ -898,7 +894,7 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
|
||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("original response"))
|
||||
})
|
||||
|
||||
@@ -913,12 +909,12 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "true", w.Header().Get("X-Modified"))
|
||||
assert.Equal(t, "Modified: GET /test\n", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type noopWriteCloser struct {
|
||||
@@ -31,7 +30,7 @@ var (
|
||||
testFilesLock sync.Mutex
|
||||
)
|
||||
|
||||
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||
func openFile(path string) (io.WriteCloser, error) {
|
||||
switch path {
|
||||
case "/dev/stdout":
|
||||
return stdout, nil
|
||||
|
||||
@@ -41,6 +41,7 @@ const (
|
||||
OnRoute = "route"
|
||||
|
||||
// on response
|
||||
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
@@ -59,10 +60,11 @@ var checkers = map[string]struct {
|
||||
),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
||||
@@ -251,7 +253,7 @@ var checkers = map[string]struct {
|
||||
"proto": "the http protocol (http, https, h3)",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
validate: func(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -581,7 +583,7 @@ func (on *RuleOn) Parse(v string) error {
|
||||
}
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.Add(err.Subjectf("line %d", i+1))
|
||||
errs.AddSubjectf(err, "line %d", i+1)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
@@ -603,7 +605,7 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
func parseOn(line string) (Checker, bool, error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
@@ -645,7 +647,7 @@ func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, err.With(checker.help.Error())
|
||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
||||
}
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
|
||||
@@ -31,7 +31,7 @@ var quoteChars = [256]bool{
|
||||
// error 403 "Forbidden 'foo' 'bar'"
|
||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||
func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
func parse(v string) (subject string, args []string, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
||||
|
||||
escaped := false
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
@@ -15,7 +13,6 @@ func TestParser(t *testing.T) {
|
||||
input string
|
||||
subject string
|
||||
args []string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
@@ -93,10 +90,6 @@ func TestParser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, args, err := parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
expect.ErrorIs(t, tt.wantErr, err)
|
||||
return
|
||||
}
|
||||
// t.Log(subject, args, err)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
@@ -105,12 +98,8 @@ func TestParser(t *testing.T) {
|
||||
}
|
||||
t.Run("env substitution", func(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
|
||||
os.Setenv("DOMAIN", "example.com")
|
||||
defer func() {
|
||||
os.Unsetenv("CLOUDFLARE_API_KEY")
|
||||
os.Unsetenv("DOMAIN")
|
||||
}()
|
||||
t.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
|
||||
t.Setenv("DOMAIN", "example.com")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
//go:embed *.yml
|
||||
@@ -35,12 +34,12 @@ func initPresets() {
|
||||
var rules rules.Rules
|
||||
content, err := fs.ReadFile(file.Name())
|
||||
if err != nil {
|
||||
gperr.LogError("failed to read rule preset", err)
|
||||
log.Err(err).Msg("failed to read rule preset")
|
||||
continue
|
||||
}
|
||||
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&rules))
|
||||
if err != nil {
|
||||
gperr.LogError("failed to unmarshal rule preset", err)
|
||||
log.Err(err).Msg("failed to unmarshal rule preset")
|
||||
continue
|
||||
}
|
||||
rulePresets[file.Name()] = rules
|
||||
|
||||
@@ -3,12 +3,19 @@
|
||||
do: pass
|
||||
- name: protected
|
||||
on: |
|
||||
!path regex("(_next/static|_next/image|favicon.ico).*")
|
||||
!path glob("/api/v1/auth/*")
|
||||
!path glob("/auth/*")
|
||||
!path regex("[A-Za-z0-9_-]+\.(svg|png|jpg|jpeg|gif|ico|webp|woff2?|eot|ttf|otf|txt)(\?.+)?")
|
||||
!path /icon0.svg
|
||||
!path /favicon.ico
|
||||
!path /apple-icon.png
|
||||
!path glob("/web-app-manifest-*x*.png")
|
||||
!path regex("\/assets\/(chunks\/)?[a-zA-Z0-9\-_]+\.(css|js|woff2)")
|
||||
!path regex("\/assets\/workbox-window\.prod\.es5-[a-zA-Z0-9]+\.js")
|
||||
!path regex("/workbox-[a-zA-Z0-9]+\.js")
|
||||
!path /api/v1/version
|
||||
!path /manifest.json
|
||||
!path /sw.js
|
||||
!path /registerSW.js
|
||||
do: require_auth
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
|
||||
26
internal/route/rules/presets/webui_dev.yml
Normal file
26
internal/route/rules/presets/webui_dev.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
- name: login page
|
||||
on: path /login
|
||||
do: pass
|
||||
- name: protected
|
||||
on: |
|
||||
!path glob("/@tanstack-start/*")
|
||||
!path glob("/@vite-plugin-pwa/*")
|
||||
!path glob("/__tsd/*")
|
||||
!path /@react-refresh
|
||||
!path /@vite/client
|
||||
!path regex("/\?token=[a-zA-Z0-9-_]+")
|
||||
!path glob("/@id/*")
|
||||
!path glob("/api/v1/auth/*")
|
||||
!path glob("/auth/*")
|
||||
!path regex("([A-Za-z0-9_\-/]+)+\.(css|ts|js|mjs|svg|png|jpg|jpeg|gif|ico|webp|woff2?|eot|ttf|otf|txt)(\?.*)?")
|
||||
!path /api/v1/version
|
||||
!path /manifest.json
|
||||
do: require_auth
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
do: proxy http://${API_ADDR}/
|
||||
- name: proxy to auth api
|
||||
on: path glob("/auth/*")
|
||||
do: |
|
||||
rewrite /auth /api/v1/auth
|
||||
proxy http://${API_ADDR}/
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog/log"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
@@ -58,7 +57,7 @@ func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
}
|
||||
|
||||
func (rules Rules) Validate() gperr.Error {
|
||||
func (rules Rules) Validate() error {
|
||||
var defaultRulesFound []int
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
ValidateFunc func(args []string) (any, gperr.Error)
|
||||
ValidateFunc func(args []string) (any, error)
|
||||
Tuple[T1, T2 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
@@ -62,7 +62,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
func validateSingleMatcher(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
return &MapValueMatcher{args[0], nil}, nil
|
||||
@@ -85,7 +85,7 @@ func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
func toKeyValueTemplate(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
func validateURL(args []string) (any, gperr.Error) {
|
||||
func validateURL(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func validateAbsoluteURL(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||
func validateCIDR(args []string) (any, gperr.Error) {
|
||||
func validateCIDR(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -149,7 +149,7 @@ func validateCIDR(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateURLPath returns string with the path validated.
|
||||
func validateURLPath(args []string) (any, gperr.Error) {
|
||||
func validateURLPath(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -166,7 +166,7 @@ func validateURLPath(args []string) (any, gperr.Error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
func validateURLPathMatcher(args []string) (any, error) {
|
||||
path, err := validateURLPath(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -175,7 +175,7 @@ func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateFSPath returns string with the path validated.
|
||||
func validateFSPath(args []string) (any, gperr.Error) {
|
||||
func validateFSPath(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func validateFSPath(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateMethod returns string with the method validated.
|
||||
func validateMethod(args []string) (any, gperr.Error) {
|
||||
func validateMethod(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func validateStatusCode(status string) (int, error) {
|
||||
// - 3xx
|
||||
// - 4xx
|
||||
// - 5xx
|
||||
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
func validateStatusRange(args []string) (any, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
func validateUserBCryptPassword(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -258,7 +258,7 @@ func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
}
|
||||
|
||||
// validateModField returns CommandHandler with the field validated.
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.Error) {
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, ErrExpectTwoOrThreeArgs
|
||||
}
|
||||
@@ -275,7 +275,7 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
|
||||
}
|
||||
validArgs, err := setField.validate(args[1:])
|
||||
if err != nil {
|
||||
return nil, err.With(setField.help.Error())
|
||||
return nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
}
|
||||
modder := setField.builder(validArgs)
|
||||
switch mod {
|
||||
@@ -299,7 +299,7 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error) {
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
@@ -310,22 +310,15 @@ func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error
|
||||
|
||||
err := ValidateVars(tmplStr)
|
||||
if err != nil {
|
||||
return templateString{}, gperr.Wrap(err)
|
||||
return templateString{}, err
|
||||
}
|
||||
return templateString{tmplStr, true}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
func validateLevel(level string) (zerolog.Level, error) {
|
||||
l, err := zerolog.ParseLevel(level)
|
||||
if err != nil {
|
||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// func validateNotifProvider(provider string) gperr.Error {
|
||||
// if !notif.HasProvider(provider) {
|
||||
// return ErrInvalidArguments.Subject(provider)
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
@@ -2,6 +2,7 @@ package rules
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
@@ -11,9 +12,9 @@ import (
|
||||
|
||||
func BenchmarkExpandVars(b *testing.B) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
testResponseModifier.Write([]byte("Hello, world!"))
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Set("X-Custom", "value1,value2")
|
||||
testRequest.ContentLength = 12345
|
||||
|
||||
@@ -203,7 +203,7 @@ func TestExpandVars(t *testing.T) {
|
||||
postFormData.Add("postmulti", "first")
|
||||
postFormData.Add("postmulti", "second")
|
||||
|
||||
testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode()))
|
||||
testRequest := httptest.NewRequest(http.MethodPost, "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", strings.NewReader(postFormData.Encode()))
|
||||
testRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Add("X-Custom", "value1")
|
||||
@@ -218,7 +218,7 @@ func TestExpandVars(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Content-Type", "text/html")
|
||||
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
// set content length to 9876 by writing 9876 'a' bytes
|
||||
testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876))
|
||||
|
||||
@@ -498,12 +498,12 @@ func TestExpandVars(t *testing.T) {
|
||||
|
||||
func TestExpandVars_Integration(t *testing.T) {
|
||||
t.Run("complex log format", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "https://api.example.com/users/123?sort=asc", nil)
|
||||
testRequest.Header.Set("User-Agent", "curl/7.68.0")
|
||||
testRequest.RemoteAddr = "10.0.0.1:54321"
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
@@ -515,7 +515,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("with query parameters", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "http://example.com/search?q=test&page=1", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
@@ -529,12 +529,12 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("response headers", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Cache-Control", "no-cache")
|
||||
testResponseModifier.Header().Set("X-Rate-Limit", "100")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
@@ -554,7 +554,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "http scheme",
|
||||
request: httptest.NewRequest("GET", "http://example.com/", nil),
|
||||
request: httptest.NewRequest(http.MethodGet, "http://example.com/", nil),
|
||||
expected: "http",
|
||||
},
|
||||
{
|
||||
@@ -581,7 +581,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
|
||||
func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
// Upstream variables require context from routes package
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
@@ -607,7 +607,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
|
||||
func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
// Test request without port in Host header
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.Host = "example.com" // No port
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
@@ -623,13 +623,13 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
// Test request without port in RemoteAddr
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
testRequest.RemoteAddr = "192.168.1.1" // No port
|
||||
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
@@ -638,19 +638,19 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
|
||||
t.Run("remote_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
require.Empty(t, out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/test", nil)
|
||||
testRequest := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
|
||||
@@ -2,19 +2,18 @@ package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/health/monitor"
|
||||
"github.com/yusing/godoxy/internal/idlewatcher"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/route/stream"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/task"
|
||||
)
|
||||
|
||||
@@ -22,11 +21,11 @@ import (
|
||||
type StreamRoute struct {
|
||||
*Route
|
||||
stream nettypes.Stream
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
|
||||
func NewStreamRoute(base *Route) (types.Route, gperr.Error) {
|
||||
var _ types.StreamRoute = (*StreamRoute)(nil)
|
||||
|
||||
func NewStreamRoute(base *Route) (types.Route, error) {
|
||||
// TODO: support non-coherent scheme
|
||||
return &StreamRoute{Route: base}, nil
|
||||
}
|
||||
@@ -36,25 +35,26 @@ func (r *StreamRoute) Stream() nettypes.Stream {
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
||||
func (r *StreamRoute) Start(parent task.Parent) error {
|
||||
if r.LisURL == nil {
|
||||
return gperr.Errorf("listen URL is not set")
|
||||
return errors.New("listen URL is not set")
|
||||
}
|
||||
|
||||
stream, err := r.initStream()
|
||||
if err != nil {
|
||||
return gperr.Wrap(err)
|
||||
return err
|
||||
}
|
||||
r.stream = stream
|
||||
|
||||
r.task = parent.Subtask("stream."+r.Name(), !r.ShouldExclude())
|
||||
r.task.SetValue(monitor.DisplayNameKey{}, r.DisplayName())
|
||||
|
||||
switch {
|
||||
case r.UseIdleWatcher():
|
||||
waker, err := idlewatcher.NewWatcher(parent, r, r.IdlewatcherConfig())
|
||||
if err != nil {
|
||||
r.task.Finish(err)
|
||||
return gperr.Wrap(err, "idlewatcher error")
|
||||
return fmt.Errorf("idlewatcher error: %w", err)
|
||||
}
|
||||
r.stream = waker
|
||||
r.HealthMon = waker
|
||||
@@ -64,32 +64,26 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
||||
|
||||
if r.HealthMon != nil {
|
||||
if err := r.HealthMon.Start(r.task); err != nil {
|
||||
gperr.LogWarn("health monitor error", err, &r.l)
|
||||
log.Warn().Err(err).Msg("health monitor error")
|
||||
r.HealthMon = nil
|
||||
}
|
||||
}
|
||||
|
||||
r.ListenAndServe(r.task.Context(), nil, nil)
|
||||
r.l = log.With().
|
||||
Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme).
|
||||
Str("name", r.Name()).
|
||||
Stringer("rurl", r.ProxyURL).
|
||||
Stringer("laddr", r.LocalAddr()).Logger()
|
||||
r.l.Info().Msg("stream started")
|
||||
|
||||
r.task.OnCancel("close_stream", func() {
|
||||
r.stream.Close()
|
||||
r.l.Info().Msg("stream closed")
|
||||
})
|
||||
|
||||
routes.Stream.Add(r)
|
||||
r.task.OnCancel("remove_route_from_stream", func() {
|
||||
routes.Stream.Del(r)
|
||||
})
|
||||
ep := entrypoint.FromCtx(parent.Context())
|
||||
if ep == nil {
|
||||
err := errors.New("entrypoint not initialized")
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
if err := ep.StartAddRoute(r); err != nil {
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
r.stream.ListenAndServe(ctx, preDial, onRead)
|
||||
func (r *StreamRoute) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
|
||||
return r.stream.ListenAndServe(ctx, preDial, onRead)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Close() error {
|
||||
|
||||
@@ -63,10 +63,9 @@ func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro
|
||||
|
||||
```go
|
||||
type Stream interface {
|
||||
ListenAndServe(ctx context.Context, preDial, onRead HookFunc)
|
||||
ListenAndServe(ctx context.Context, preDial, onRead HookFunc) error
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
zerolog.LogObjectMarshaler
|
||||
}
|
||||
|
||||
type HookFunc func(ctx context.Context) error
|
||||
|
||||
@@ -8,5 +8,5 @@ import (
|
||||
)
|
||||
|
||||
func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {
|
||||
log.Debug().Object("stream", stream).Msgf(format, v...)
|
||||
log.Debug().EmbedObject(stream).Msgf(format, v...)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func logErr(stream zerolog.LogObjectMarshaler, err error, msg string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Err(err).Object("stream", stream).Msg(msg)
|
||||
log.Err(err).EmbedObject(stream).Msg(msg)
|
||||
}
|
||||
|
||||
func logErrf(stream zerolog.LogObjectMarshaler, err error, format string, v ...any) {
|
||||
@@ -37,5 +37,5 @@ func logErrf(stream zerolog.LogObjectMarshaler, err error, format string, v ...a
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Err(err).Object("stream", stream).Msgf(format, v...)
|
||||
log.Err(err).EmbedObject(stream).Msgf(format, v...)
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
acl "github.com/yusing/godoxy/internal/acl/types"
|
||||
"github.com/yusing/godoxy/internal/agentpool"
|
||||
"github.com/yusing/godoxy/internal/entrypoint"
|
||||
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
"go.uber.org/atomic"
|
||||
@@ -43,26 +43,29 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age
|
||||
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
|
||||
var err error
|
||||
s.listener, err = net.ListenTCP(s.network, s.laddr)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to listen")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
if acl, ok := ctx.Value(acl.ContextKey{}).(*acl.Config); ok {
|
||||
log.Debug().Str("listener", s.listener.Addr().String()).Msg("wrapping listener with ACL")
|
||||
s.listener = acl.WrapTCP(s.listener)
|
||||
if ep := entrypoint.FromCtx(ctx); ep != nil {
|
||||
if proxyProto := ep.SupportProxyProtocol(); proxyProto {
|
||||
log.Debug().EmbedObject(s).Msg("wrapping listener with proxy protocol")
|
||||
s.listener = &proxyproto.Listener{Listener: s.listener}
|
||||
}
|
||||
}
|
||||
|
||||
if proxyProto := entrypoint.ActiveConfig.Load().SupportProxyProtocol; proxyProto {
|
||||
s.listener = &proxyproto.Listener{Listener: s.listener}
|
||||
if aclCfg := acl.FromCtx(ctx); aclCfg != nil {
|
||||
log.Debug().EmbedObject(s).Msg("wrapping listener with ACL")
|
||||
s.listener = aclCfg.WrapTCP(s.listener)
|
||||
}
|
||||
|
||||
s.preDial = preDial
|
||||
s.onRead = onRead
|
||||
go s.listen(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) Close() error {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
acl "github.com/yusing/godoxy/internal/acl/types"
|
||||
"github.com/yusing/godoxy/internal/agentpool"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/goutils/synk"
|
||||
@@ -75,21 +75,21 @@ func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
|
||||
l, err := net.ListenUDP(s.network, s.laddr)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to listen")
|
||||
return
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
if acl, ok := ctx.Value(acl.ContextKey{}).(*acl.Config); ok {
|
||||
log.Debug().Str("listener", s.listener.LocalAddr().String()).Msg("wrapping listener with ACL")
|
||||
s.listener = acl.WrapUDP(s.listener)
|
||||
if aclCfg := acl.FromCtx(ctx); aclCfg != nil {
|
||||
log.Debug().EmbedObject(s).Msg("wrapping listener with ACL")
|
||||
s.listener = aclCfg.WrapUDP(l)
|
||||
}
|
||||
s.preDial = preDial
|
||||
s.onRead = onRead
|
||||
go s.listen(ctx)
|
||||
go s.cleanUp(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) Close() error {
|
||||
|
||||
32
internal/route/test_route.go
Normal file
32
internal/route/test_route.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/godoxy/internal/entrypoint"
|
||||
epctx "github.com/yusing/godoxy/internal/entrypoint/types"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
"github.com/yusing/goutils/task"
|
||||
)
|
||||
|
||||
func NewStartedTestRoute(t testing.TB, base *Route) (types.Route, error) {
|
||||
t.Helper()
|
||||
|
||||
task := task.GetTestTask(t)
|
||||
if ep := epctx.FromCtx(task.Context()); ep == nil {
|
||||
ep = entrypoint.NewEntrypoint(task, nil)
|
||||
epctx.SetCtx(task, ep)
|
||||
}
|
||||
|
||||
err := base.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = base.Start(task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return base.impl, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package route
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -25,7 +26,8 @@ type HTTPConfig struct {
|
||||
}
|
||||
|
||||
// BuildTLSConfig creates a TLS configuration based on the HTTP config options.
|
||||
func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, gperr.Error) {
|
||||
func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, error) {
|
||||
//nolint:gosec
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Handle InsecureSkipVerify (legacy NoTLSVerify option)
|
||||
@@ -54,15 +56,12 @@ func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, gperr.Er
|
||||
if cfg.SSLTrustedCertificate != "" {
|
||||
caCertData, err := os.ReadFile(cfg.SSLTrustedCertificate)
|
||||
if err != nil {
|
||||
return nil, gperr.New("failed to read trusted certificate file").
|
||||
Subject(cfg.SSLTrustedCertificate).
|
||||
With(err)
|
||||
return nil, gperr.PrependSubject(err, cfg.SSLTrustedCertificate)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCertData) {
|
||||
return nil, gperr.New("failed to parse trusted certificates").
|
||||
Subject(cfg.SSLTrustedCertificate)
|
||||
return nil, gperr.PrependSubject(errors.New("failed to parse trusted certificates"), cfg.SSLTrustedCertificate)
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
@@ -70,16 +69,16 @@ func (cfg *HTTPConfig) BuildTLSConfig(targetURL *url.URL) (*tls.Config, gperr.Er
|
||||
// Handle ssl_certificate and ssl_certificate_key (client certificates)
|
||||
if cfg.SSLCertificate != "" {
|
||||
if cfg.SSLCertificateKey == "" {
|
||||
return nil, gperr.New("ssl_certificate_key is required when ssl_certificate is specified")
|
||||
return nil, errors.New("ssl_certificate_key is required when ssl_certificate is specified")
|
||||
}
|
||||
|
||||
clientCert, err := tls.LoadX509KeyPair(cfg.SSLCertificate, cfg.SSLCertificateKey)
|
||||
if err != nil {
|
||||
return nil, gperr.New("failed to load client certificate").
|
||||
Subject(cfg.SSLCertificate).
|
||||
With(err)
|
||||
return nil, gperr.PrependSubject(err, cfg.SSLCertificate)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{clientCert}
|
||||
} else if cfg.SSLCertificateKey != "" {
|
||||
return nil, errors.New("ssl_certificate is required when ssl_certificate_key is specified")
|
||||
}
|
||||
|
||||
// Handle ssl_protocols (TLS versions)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
@@ -13,8 +14,8 @@ type Port struct {
|
||||
} // @name Port
|
||||
|
||||
var (
|
||||
ErrInvalidPortSyntax = gperr.New("invalid port syntax, expect [listening_port:]target_port")
|
||||
ErrPortOutOfRange = gperr.New("port out of range")
|
||||
ErrInvalidPortSyntax = errors.New("invalid port syntax, expect [listening_port:]target_port")
|
||||
ErrPortOutOfRange = errors.New("port out of range")
|
||||
)
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
@@ -30,7 +31,7 @@ func (p *Port) Parse(v string) (err error) {
|
||||
p.Proxy, err2 = strconv.Atoi(parts[1])
|
||||
err = gperr.Join(err, err2)
|
||||
default:
|
||||
return ErrInvalidPortSyntax.Subject(v)
|
||||
return gperr.PrependSubject(ErrInvalidPortSyntax, v)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -38,11 +39,11 @@ func (p *Port) Parse(v string) (err error) {
|
||||
}
|
||||
|
||||
if p.Listening < MinPort || p.Listening > MaxPort {
|
||||
return ErrPortOutOfRange.Subjectf("%d", p.Listening)
|
||||
return gperr.PrependSubject(ErrPortOutOfRange, strconv.Itoa(p.Listening))
|
||||
}
|
||||
|
||||
if p.Proxy < MinPort || p.Proxy > MaxPort {
|
||||
return ErrPortOutOfRange.Subjectf("%d", p.Proxy)
|
||||
return gperr.PrependSubject(ErrPortOutOfRange, strconv.Itoa(p.Proxy))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package route
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
|
||||
type Scheme uint8
|
||||
|
||||
var ErrInvalidScheme = gperr.New("invalid scheme")
|
||||
var ErrInvalidScheme = errors.New("invalid scheme")
|
||||
|
||||
const (
|
||||
SchemeHTTP Scheme = 1 << iota
|
||||
@@ -79,7 +80,7 @@ func (s *Scheme) Parse(v string) error {
|
||||
case schemeStrFileServer:
|
||||
*s = SchemeFileServer
|
||||
default:
|
||||
return ErrInvalidScheme.Subject(v)
|
||||
return gperr.PrependSubject(ErrInvalidScheme, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user