Compare commits

...

49 Commits
0.6 ... 0.7.3

Author SHA1 Message Date
yusing
3bf520541b fixed stats websocket endpoint when no match_domains configured 2024-11-03 06:04:35 +08:00
yusing
a531896bd6 updated docker compose example 2024-11-03 05:27:27 +08:00
yusing
e005b42d18 readme indentation 2024-11-03 04:50:40 +08:00
yusing
1f6573b6da readme update 2024-11-03 04:43:17 +08:00
yusing
73af381c4c updated auto and manual setup guide 2024-11-03 04:38:09 +08:00
yusing
625bf4dfdc improved HTTP performance, especially when match_domains are used; api json string fix 2024-11-02 07:14:03 +08:00
yusing
46b4090629 moved env to .env.example, update setup method 2024-11-02 03:36:55 +08:00
yusing
91e012987e added option for jwt token ttl 2024-11-02 03:21:47 +08:00
yusing
a86d316d07 refactor and typo fixes 2024-11-02 03:14:47 +08:00
yusing
76454df5e6 remove useless dummytask 2024-11-02 03:14:25 +08:00
yusing
67b6e40f85 remove unused code 2024-11-02 03:04:15 +08:00
yusing
9889b5a8d3 correcting startup behavior 2024-11-02 03:00:14 +08:00
yusing
00fc75b61b update setup default branch 2024-10-31 03:07:03 +08:00
yusing
4ee93a1351 added logout endpoint 2024-10-31 00:58:35 +08:00
yusing
669d13b89a fixed crash on some base64 jwt secret 2024-10-30 11:09:52 +08:00
yusing
5fa86b5eb7 fixed auth internal server error 2024-10-30 10:19:30 +08:00
yusing
369cdf8c4f fixed config reload 2024-10-30 06:52:18 +08:00
yusing
0397f69853 fixed notification not being sent 2024-10-30 06:44:10 +08:00
yusing
81177926ff implemented login and jwt auth 2024-10-30 06:25:32 +08:00
yusing
e5bbb18414 migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher 2024-10-29 11:34:58 +08:00
yusing
cfa74d69ae fixed out of range error on gotify message being sent 2024-10-22 16:46:12 +08:00
yusing
bee26f43d4 initial gotify support 2024-10-22 05:38:09 +08:00
yusing
a3ab32e9ab reload no longer be skipped when there're errors 2024-10-20 15:05:04 +08:00
yusing
c847fe4747 fixed schema and json tag, hide http://:0 2024-10-20 11:04:44 +08:00
yusing
a278711421 fixed loadbalancer with idlewatcher, fixed reload issue 2024-10-20 09:46:02 +08:00
yusing
01ffe0d97c simplified error messages 2024-10-19 14:25:28 +08:00
yusing
bd732dfa0a fixed homepage item not showing 2024-10-19 01:30:34 +08:00
yusing
8b8e1773e8 fixed loadbalanced routes with same alias cause conflict 2024-10-19 01:20:08 +08:00
yusing
b296fb2965 fixed healthcheck failed to disable and nil dereference 2024-10-19 00:13:55 +08:00
yusing
53557e38b6 Fixed a few issues:
- Incorrect name being shown on dashboard "Proxies page"
- Apps being shown when homepage.show is false
- Load balanced routes are shown on homepage instead of the load balancer
- Route with idlewatcher will now be removed on container destroy
- Idlewatcher panic
- Performance improvement
- Idlewatcher infinitely loading
- Reload stucked / not working properly
- Streams stuck on shutdown / reload
- etc...
Added:
- support idlewatcher for loadbalanced routes
- partial implementation for stream type idlewatcher
Issues:
- graceful shutdown
2024-10-18 16:47:01 +08:00
yusing
c0c61709ca fixed idlewatcher panic, dashboard app name, route not removing on container destroy 2024-10-16 00:57:10 +08:00
yusing
56b778f19c fixed json output for ls-routes and its API and homepage api 2024-10-15 16:23:46 +08:00
yusing
f4d532598c Improved healthcheck, idlewatcher support for loadbalanced routes, bug fixes 2024-10-15 15:34:27 +08:00
yusing
53fa28ae77 graceful shutdown and ref count related 2024-10-14 10:31:27 +08:00
yusing
f38b3abdbc improved health check 2024-10-14 10:02:53 +08:00
yusing
99207ae606 routes in loadbalance pool no longer listed in ls-route and its API, the loadbalancer is listed instead. improved context handling and grateful shutdown 2024-10-14 09:28:54 +08:00
yusing
d3b8cb8cba fixed possible memory leak for UDP stream 2024-10-14 08:57:46 +08:00
yusing
51c6eb4597 fixed problems related to reload and port selection 2024-10-14 08:55:49 +08:00
yusing
d47b672aa5 refactored some stuff, added healthcheck support, fixed 'include file' reload not showing in log 2024-10-12 13:56:38 +08:00
yusing
64e30f59e8 added missing json tags 2024-10-11 10:38:38 +08:00
yusing
cef7b3d396 updated tests for new behavior 2024-10-11 10:00:10 +08:00
yusing
7184c9cfe9 correcting some behaviors for $DOCKER_HOST, now uses container's private IP instead of localhost 2024-10-11 09:13:38 +08:00
yusing
da04a0dff4 added golangci-linting, refactor, simplified error msgs and fixed some error handling 2024-10-10 11:52:09 +08:00
yusing
d91b66ae87 performance improvement and small fix on loadbalancer 2024-10-09 18:10:51 +08:00
yusing
5c40f4aa84 added round_robin, least_conn and ip_hash load balance support, small refactoring 2024-10-09 10:39:07 +08:00
yusing
1797896fa6 fixed typos and formatting, fixed loading page not being shown (idlewaker) 2024-10-08 13:15:23 +08:00
yusing
d1c9e18c97 improved idlewatcher support for API-like services, fixed idlewaker proxying to zero port 2024-10-07 18:50:51 +08:00
yusing
ef83ed0596 improved idlewatcher and content type matching, update CI 2024-10-07 17:41:08 +08:00
yusing
d89155a6ee idlewatcher fixed idlewatcher incorrect respond haviour, keep url path 2024-10-07 16:44:18 +08:00
203 changed files with 8497 additions and 4359 deletions

22
.env.example Normal file
View File

@@ -0,0 +1,22 @@
# set timezone to get correct log timestamp
TZ=ETC/UTC
# generate secret with `openssl rand -base64 32`
GOPROXY_API_JWT_SECRET=
# the JWT token time-to-live
GOPROXY_API_JWT_TOKEN_TTL=1h
# API/WebUI login credentials
GOPROXY_API_USER=admin
GOPROXY_API_PASSWORD=password
# Proxy listening address
GOPROXY_HTTP_ADDR=:80
GOPROXY_HTTPS_ADDR=:443
# API listening address
GOPROXY_API_ADDR=127.0.0.1:8888
# Debug mode
GOPROXY_DEBUG=false

View File

@@ -11,7 +11,7 @@ env:
jobs:
build:
name: Build multi-platform Docker image
runs-on: self-hosted
runs-on: ubuntu-22.04
permissions:
contents: read
@@ -85,7 +85,7 @@ jobs:
if-no-files-found: error
retention-days: 1
merge:
runs-on: self-hosted
runs-on: ubuntu-22.04
needs:
- build
permissions:

3
.gitignore vendored
View File

@@ -1,6 +1,8 @@
compose.yml
*.compose.yml
config
certs
config*/
certs*/
bin/
@@ -22,3 +24,4 @@ todo.md
.aider*
mtrace.json
.env
test.Dockerfile

137
.golangci.yml Normal file
View File

@@ -0,0 +1,137 @@
run:
timeout: 10m
linters-settings:
govet:
enable-all: true
disable:
- shadow
- fieldalignment
gocyclo:
min-complexity: 14
goconst:
min-len: 3
min-occurrences: 4
misspell:
locale: US
funlen:
lines: -1
statements: 120
forbidigo:
forbid:
- ^print(ln)?$
godox:
keywords:
- FIXME
tagalign:
align: false
sort: true
order:
- description
- json
- toml
- yaml
- yml
- label
- label-slice-as-struct
- file
- kv
- export
stylecheck:
dot-import-whitelist:
- github.com/yusing/go-proxy/internal/utils/testing # go tests only
- github.com/yusing/go-proxy/internal/api/v1/utils # api only
revive:
rules:
- name: struct-tag
- name: blank-imports
- name: context-as-argument
- name: context-keys-type
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
disabled: true
- name: if-return
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
disabled: true
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
- name: errorf
- name: empty-block
- name: superfluous-else
- name: unused-parameter
disabled: true
- name: unreachable-code
- name: redefines-builtin-id
gomoddirectives:
replace-allow-list:
- github.com/abbot/go-http-auth
- github.com/gorilla/mux
- github.com/mailgun/minheap
- github.com/mailgun/multibuf
- github.com/jaguilar/vt100
- github.com/cucumber/godog
- github.com/http-wasm/http-wasm-host-go
testifylint:
disable:
- suite-dont-use-pkg
- require-error
- go-require
staticcheck:
checks:
- all
- -SA1019
errcheck:
exclude-functions:
- fmt.Fprintln
linters:
enable-all: true
disable:
- execinquery # deprecated
- gomnd # deprecated
- sqlclosecheck # not relevant (SQL)
- rowserrcheck # not relevant (SQL)
- cyclop # duplicate of gocyclo
- depguard # Not relevant
- nakedret # Too strict
- lll # Not relevant
- gocyclo # FIXME must be fixed
- gocognit # Too strict
- nestif # Too many false-positive.
- prealloc # Too many false-positive.
- makezero # Not relevant
- dupl # Too strict
- gci # I don't care
- gosec # Too strict
- gochecknoinits
- gochecknoglobals
- wsl # Too strict
- nlreturn # Not relevant
- mnd # Too strict
- testpackage # Too strict
- tparallel # Not relevant
- paralleltest # Not relevant
- exhaustive # Not relevant
- exhaustruct # Not relevant
- err113 # Too strict
- wrapcheck # Too strict
- noctx # Too strict
- bodyclose # too many false-positive
- forcetypeassert # Too strict
- tagliatelle # Too strict
- varnamelen # Not relevant
- nilnil # Not relevant
- ireturn # Not relevant
- contextcheck # too many false-positive
- containedctx # too many false-positive
- maintidx # kind of duplicate of gocyclo
- nonamedreturns # Too strict
- gosmopolitan # not relevant
- exportloopref # Not relevant since go1.22

9
.trunk/.gitignore vendored Normal file
View File

@@ -0,0 +1,9 @@
*out
*logs
*actions
*notifications
*tools
plugins
user_trunk.yaml
user.yaml
tmp

41
.trunk/trunk.yaml Normal file
View File

@@ -0,0 +1,41 @@
# This file controls the behavior of Trunk: https://docs.trunk.io/cli
# To learn more about the format of this file, see https://docs.trunk.io/reference/trunk-yaml
version: 0.1
cli:
version: 1.22.6
# Trunk provides extensibility via plugins. (https://docs.trunk.io/plugins)
plugins:
sources:
- id: trunk
ref: v1.6.3
uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes:
enabled:
- node@18.12.1
- python@3.10.8
- go@1.23.2
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
lint:
enabled:
- hadolint@2.12.0
- actionlint@1.7.3
- checkov@3.2.257
- git-diff-check
- gofmt@1.20.4
- golangci-lint@1.61.0
- markdownlint@0.42.0
- osv-scanner@1.9.0
- oxipng@9.1.2
- prettier@3.3.3
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.82.7
- yamllint@1.35.1
actions:
disabled:
- trunk-announce
- trunk-check-pre-push
- trunk-fmt-pre-commit
enabled:
- trunk-upgrade-available

View File

@@ -28,14 +28,20 @@ get:
go get -u ./cmd && go mod tidy
debug:
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
make build
sudo GOPROXY_DEBUG=1 bin/go-proxy
debug-trace:
make build
sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
profile:
GODEBUG=gctrace=1 make build
sudo GOPROXY_DEBUG=1 bin/go-proxy
mtrace:
bin/go-proxy debug-ls-mtrace > mtrace.json
run-test:
make build && sudo GOPROXY_TEST=1 bin/go-proxy
run:
make build && sudo bin/go-proxy
@@ -49,7 +55,7 @@ repush:
git push gitlab dev --force
rapid-crash:
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\
sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\
sleep 3 &&\
sudo docker rm -f test_crash
@@ -58,4 +64,7 @@ debug-list-containers:
ci-test:
mkdir -p /tmp/artifacts
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
cloc:
cloc --not-match-f '_test.go$$' cmd internal pkg

View File

@@ -24,6 +24,8 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
- [Key Features](#key-features)
- [Getting Started](#getting-started)
- [Setup](#setup)
- [Manual Setup](#manual-setup)
- [Folder structrue](#folder-structrue)
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
- [Screenshots](#screenshots)
- [idlesleeper](#idlesleeper)
@@ -59,10 +61,12 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
docker pull ghcr.io/yusing/go-proxy:latest
```
2. Create new directory, `cd` into it, then run setup
2. Create new directory, `cd` into it, then run setup, or [set up manually](#manual-setup)
```shell
docker run --rm -v .:/setup ghcr.io/yusing/go-proxy /app/go-proxy setup
# Then set the JWT secret
sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env
```
3. Setup DNS Records point to machine which runs `go-proxy`, e.g.
@@ -83,6 +87,43 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
[🔼Back to top](#table-of-content)
### Manual Setup
1. Make `config` directory then grab `config.example.yml` into `config/config.yml`
`mkdir -p config && wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/config.example.yml -O config/config.yml`
2. Grab `.env.example` into `.env`
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/.env.example -O .env`
3. Grab `compose.example.yml` into `compose.yml`
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/compose.example.yml -O compose.yml`
4. Set the JWT secret
`sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env`
5. Start the container `docker compose up -d`
### Folder structrue
```shell
├── certs
│ ├── cert.crt
│ └── priv.key
├── compose.yml
├── config
│ ├── config.yml
│ ├── middlewares
│ │ ├── middleware1.yml
│ │ ├── middleware2.yml
│ ├── provider1.yml
│ └── provider2.yml
└── .env
```
### Use JSON Schema in VSCode
Copy [`.vscode/settings.example.json`](.vscode/settings.example.json) to `.vscode/settings.json` and modify it to fit your needs

View File

@@ -1,76 +1,79 @@
package main
import (
"context"
"encoding/json"
"io"
"log"
"net/http"
"os"
"os/signal"
"reflect"
"runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/middleware"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg"
)
func main() {
args := common.GetArgs()
if args.Command == common.CommandSetup {
switch args.Command {
case common.CommandSetup:
internal.Setup()
return
}
l := logrus.WithField("module", "main")
onShutdown := F.NewSlice[func()]()
if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel)
}
if args.Command != common.CommandStart {
logrus.SetOutput(io.Discard)
} else {
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
})
logrus.Infof("go-proxy version %s", pkg.GetVersion())
}
if args.Command == common.CommandReload {
case common.CommandReload:
if err := query.ReloadServer(); err != nil {
E.LogFatal("server reload error", err)
}
logging.Info().Msg("ok")
return
case common.CommandListIcons:
icons, err := internal.ListAvailableIcons()
if err != nil {
log.Fatal(err)
}
log.Print("ok")
printJSON(icons)
return
case common.CommandListRoutes:
routes, err := query.ListRoutes()
if err != nil {
log.Printf("failed to connect to api server: %s", err)
log.Printf("falling back to config file")
printJSON(config.RoutesByAlias())
} else {
printJSON(routes)
}
return
case common.CommandDebugListMTrace:
trace, err := query.ListMiddlewareTraces()
if err != nil {
log.Fatal(err)
}
printJSON(trace)
return
}
// exit if only validate config
if args.Command == common.CommandStart {
logging.Info().Msgf("go-proxy version %s", pkg.GetVersion())
logging.Trace().Msg("trace enabled")
// logging.AddHook(notif.GetDispatcher())
} else {
logging.DiscardLogger()
}
if args.Command == common.CommandValidate {
data, err := os.ReadFile(common.ConfigPath)
if err == nil {
err = config.Validate(data).Error()
err = config.Validate(data)
}
if err != nil {
log.Fatal("config error: ", err)
@@ -85,68 +88,39 @@ func main() {
middleware.LoadComposeFiles()
if err := config.Load(); err != nil {
logrus.Warn(err)
var cfg *config.Config
var err E.Error
if cfg, err = config.Load(); err != nil {
E.LogWarn("errors in config", err)
}
cfg := config.GetInstance()
switch args.Command {
case common.CommandListConfigs:
printJSON(cfg.Value())
return
case common.CommandListRoutes:
routes, err := query.ListRoutes()
if err != nil {
log.Printf("failed to connect to api server: %s", err)
log.Printf("falling back to config file")
printJSON(cfg.RoutesByAlias())
} else {
printJSON(routes)
}
return
case common.CommandListIcons:
icons, err := internal.ListAvailableIcons()
if err != nil {
log.Fatal(err)
}
printJSON(icons)
printJSON(config.Value())
return
case common.CommandDebugListEntries:
printJSON(cfg.DumpEntries())
printJSON(config.DumpEntries())
return
case common.CommandDebugListProviders:
printJSON(cfg.DumpProviders())
printJSON(config.DumpProviders())
return
case common.CommandDebugListMTrace:
trace, err := query.ListMiddlewareTraces()
if err != nil {
log.Fatal(err)
}
printJSON(trace)
}
cfg.StartProxyProviders()
cfg.WatchChanges()
onShutdown.Add(docker.CloseAllClients)
onShutdown.Add(cfg.Dispose)
config.WatchChanges()
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT)
signal.Notify(sig, syscall.SIGTERM)
signal.Notify(sig, syscall.SIGHUP)
autocert := cfg.GetAutoCertProvider()
autocert := config.GetAutoCertProvider()
if autocert != nil {
ctx, cancel := context.WithCancel(context.Background())
if err := autocert.Setup(ctx); err != nil {
l.Fatal(err)
} else {
onShutdown.Add(cancel)
if err := autocert.Setup(); err != nil {
E.LogFatal("autocert setup error", err)
}
} else {
l.Info("autocert not configured")
logging.Info().Msg("autocert not configured")
}
proxyServer := server.InitProxyServer(server.Options{
@@ -155,75 +129,40 @@ func main() {
HTTPAddr: common.ProxyHTTPAddr,
HTTPSAddr: common.ProxyHTTPSAddr,
Handler: http.HandlerFunc(R.ProxyHandler),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
RedirectToHTTPS: config.Value().RedirectToHTTPS,
})
apiServer := server.InitAPIServer(server.Options{
Name: "api",
CertProvider: autocert,
HTTPAddr: common.APIHTTPAddr,
Handler: api.NewHandler(cfg),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
Handler: api.NewHandler(),
RedirectToHTTPS: config.Value().RedirectToHTTPS,
})
proxyServer.Start()
apiServer.Start()
onShutdown.Add(proxyServer.Stop)
onShutdown.Add(apiServer.Stop)
go idlewatcher.Start()
onShutdown.Add(idlewatcher.Stop)
// wait for signal
<-sig
// grafully shutdown
logrus.Info("shutting down")
done := make(chan struct{}, 1)
var wg sync.WaitGroup
wg.Add(onShutdown.Size())
onShutdown.ForEach(func(f func()) {
go func() {
l.Debugf("waiting for %s to complete...", funcName(f))
f()
l.Debugf("%s done", funcName(f))
wg.Done()
}()
})
go func() {
wg.Wait()
close(done)
}()
timeout := time.After(time.Duration(cfg.Value().TimeoutShutdown) * time.Second)
select {
case <-done:
logrus.Info("shutdown complete")
case <-timeout:
logrus.Info("timeout waiting for shutdown")
onShutdown.ForEach(func(f func()) {
l.Warnf("%s() is still running", funcName(f))
})
}
logging.Info().Msg("shutting down")
task.CancelGlobalContext()
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
}
func prepareDirectory(dir string) {
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0755); err != nil {
logrus.Fatalf("failed to create directory %s: %v", dir, err)
if err = os.MkdirAll(dir, 0o755); err != nil {
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
}
}
}
func funcName(f func()) string {
parts := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "/go-proxy/")
return parts[len(parts)-1]
}
func printJSON(obj any) {
j, err := E.Check(json.MarshalIndent(obj, "", " "))
j, err := json.MarshalIndent(obj, "", " ")
if err != nil {
logrus.Fatal(err)
logging.Fatal().Err(err).Send()
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq"

View File

@@ -4,31 +4,26 @@ services:
container_name: go-proxy-frontend
restart: unless-stopped
network_mode: host
env_file: .env
depends_on:
- app
# if you also want to proxy the WebUI and access it via gp.y.z
# labels:
# - proxy.aliases=gp
# - proxy.gp.port=3000
# Make sure the value is same as `GOPROXY_API_ADDR` below (if you have changed it)
#
# environment:
# GOPROXY_API_ADDR: 127.0.0.1:8888
# modify below to fit your needs
labels:
proxy.aliases: gp
proxy.#1.port: 3000
proxy.#1.middlewares.cidr_whitelist.status_code: 403
proxy.#1.middlewares.cidr_whitelist.message: IP not allowed
proxy.#1.middlewares.cidr_whitelist.allow: |
- 127.0.0.1
- 10.0.0.0/8
- 192.168.0.0/16
- 172.16.0.0/12
app:
image: ghcr.io/yusing/go-proxy:latest
container_name: go-proxy
restart: always
network_mode: host
environment:
# (Optional) change this to your timezone to get correct log timestamp
TZ: ETC/UTC
# Change these if you need
#
# GOPROXY_HTTP_ADDR: :80
# GOPROXY_HTTPS_ADDR: :443
# GOPROXY_API_ADDR: 127.0.0.1:8888
env_file: .env
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./config:/app/config

19
go.mod
View File

@@ -8,9 +8,11 @@ require (
github.com/docker/docker v27.3.1+incompatible
github.com/fsnotify/fsnotify v1.7.0
github.com/go-acme/lego/v4 v4.19.2
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/gotify/server/v2 v2.5.0
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/rs/zerolog v1.33.0
github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.30.0
golang.org/x/text v0.19.0
gopkg.in/yaml.v3 v3.0.1
@@ -19,7 +21,7 @@ require (
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudflare/cloudflare-go v0.106.0 // indirect
github.com/cloudflare/cloudflare-go v0.108.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect
@@ -32,6 +34,8 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect
@@ -40,13 +44,14 @@ require (
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect

42
go.sum
View File

@@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cloudflare/cloudflare-go v0.106.0 h1:q41gC5Wc1nfi0D1ZhSHokWcd9mGMbqC7RE7qiP+qE00=
github.com/cloudflare/cloudflare-go v0.106.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0=
github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -40,8 +41,11 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -49,6 +53,8 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gotify/server/v2 v2.5.0 h1:tJd+a5bb17X52f0EV2KxqLuyjQFKmVK1+t/iNUkP16Y=
github.com/gotify/server/v2 v2.5.0/go.mod h1:DKPMQI/FZ69iKbZvrOL6VWwRaoB9O+HDvJWVd/kiGbc=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I=
github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc=
@@ -59,6 +65,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
@@ -84,8 +96,11 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@@ -96,20 +111,20 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE=
go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -138,6 +153,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -2,13 +2,13 @@ package api
import (
"fmt"
"net"
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
"github.com/yusing/go-proxy/internal/api/v1/auth"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
)
type ServeMux struct{ *http.ServeMux }
@@ -21,43 +21,38 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
}
func NewHandler(cfg *config.Config) http.Handler {
func NewHandler() http.Handler {
mux := NewServeMux()
mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent)
mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS))
mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc())
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/reload", v1.Reload)
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List))
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List))
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List))
mux.HandleFunc("GET", "/v1/file", auth.RequireAuth(v1.GetFileContent))
mux.HandleFunc("GET", "/v1/file/{filename...}", auth.RequireAuth(v1.GetFileContent))
mux.HandleFunc("POST", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("PUT", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
return mux
}
// allow only requests to API server with host matching common.APIHTTPAddr
// allow only requests to API server with localhost.
func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug {
return f
}
return func(w http.ResponseWriter, r *http.Request) {
if r.Host != common.APIHTTPAddr {
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr)
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("invalid request"))
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
LogWarn(r).Msgf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return
}
f(w, r)
}
}
func wrap(cfg *config.Config, f func(cfg *config.Config, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
f(cfg, w, r)
}
}

View File

@@ -0,0 +1,135 @@
package auth
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/golang-jwt/jwt/v5"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
Credentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
Claims struct {
Username string `json:"username"`
jwt.RegisteredClaims
}
)
var (
ErrInvalidUsername = E.New("invalid username")
ErrInvalidPassword = E.New("invalid password")
)
func validatePassword(cred *Credentials) error {
if cred.Username != common.APIUser {
return ErrInvalidUsername.Subject(cred.Username)
}
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
return ErrInvalidPassword.Subject(cred.Password)
}
return nil
}
func LoginHandler(w http.ResponseWriter, r *http.Request) {
var creds Credentials
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
if err := validatePassword(&creds); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
claim := &Claims{
Username: creds.Username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
tokenStr, err := token.SignedString(common.APIJWTSecret)
if err != nil {
U.HandleErr(w, r, err)
return
}
http.SetCookie(w, &http.Cookie{
Name: "token",
Value: tokenStr,
Expires: expiresAt,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
w.WriteHeader(http.StatusOK)
}
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "token",
Value: "",
Expires: time.Unix(0, 0),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
w.Header().Set("location", "/login")
w.WriteHeader(http.StatusTemporaryRedirect)
}
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if common.IsDebugSkipAuth {
return next
}
return func(w http.ResponseWriter, r *http.Request) {
if checkToken(w, r) {
next(w, r)
}
}
}
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
tokenCookie, err := r.Cookie("token")
if err != nil {
U.HandleErr(w, r, E.PrependSubject("token", err), http.StatusUnauthorized)
return false
}
var claims Claims
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return common.APIJWTSecret, nil
})
switch {
case err != nil:
break
case !token.Valid:
err = E.New("invalid token")
case claims.Username != common.APIUser:
err = E.New("username mismatch").Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
}
if err != nil {
U.HandleErr(w, r, err, http.StatusForbidden)
return false
}
return true
}

View File

@@ -1,42 +0,0 @@
package v1
import (
"fmt"
"net/http"
"strings"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
R "github.com/yusing/go-proxy/internal/route"
)
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target")
if target == "" {
U.HandleErr(w, r, U.ErrMissingKey("target"), http.StatusBadRequest)
return
}
var ok bool
route := cfg.FindRoute(target)
switch {
case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return
case route.Type() == R.RouteTypeReverseProxy:
ok = IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = IsStreamHealthy(
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)
}
if ok {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusRequestTimeout)
}
}

View File

@@ -1,25 +0,0 @@
package error_page
import "net/http"
func GetHandleFunc() http.HandlerFunc {
setup()
return serveHTTP
}
func serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
content, ok := fileContentMap.Load(r.URL.Path)
if !ok {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
w.Write(content)
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/provider"
"github.com/yusing/go-proxy/internal/route/provider"
)
func GetFileContent(w http.ResponseWriter, r *http.Request) {
@@ -24,7 +24,7 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, err)
return
}
w.Write(content)
U.WriteBody(w, content)
}
func SetFileContent(w http.ResponseWriter, r *http.Request) {
@@ -39,19 +39,20 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return
}
var validateErr E.NestedError
var valErr E.Error
if filename == common.ConfigFileName {
validateErr = config.Validate(content)
valErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
validateErr = provider.Validate(content)
valErr = provider.Validate(content)
}
// no validation for include files
if validateErr != nil {
U.RespondJson(w, validateErr.JSONObject(), http.StatusBadRequest)
if valErr != nil {
U.RespondJSON(w, r, valErr, http.StatusBadRequest)
return
}
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0o644)
if err != nil {
U.HandleErr(w, r, err)
return

View File

@@ -1,34 +0,0 @@
package v1
import (
"net"
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
)
func IsSiteHealthy(url string) bool {
// try HEAD first
// if HEAD is not allowed, try GET
resp, err := U.Head(url)
if resp != nil {
resp.Body.Close()
}
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
_, err = U.Get(url)
}
if resp != nil {
resp.Body.Close()
}
return err == nil
}
func IsStreamHealthy(scheme, address string) bool {
conn, err := net.DialTimeout(scheme, address, common.DialTimeout)
if err != nil {
return false
}
conn.Close()
return true
}

View File

@@ -1,7 +1,11 @@
package v1
import "net/http"
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func Index(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("API ready"))
WriteBody(w, []byte("API ready"))
}

View File

@@ -8,54 +8,69 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
const (
ListRoutes = "routes"
ListConfigFiles = "config_files"
ListMiddlewares = "middlewares"
ListMiddlewareTrace = "middleware_trace"
ListMatchDomains = "match_domains"
ListHomepageConfig = "homepage_config"
ListRoute = "route"
ListRoutes = "routes"
ListConfigFiles = "config_files"
ListMiddlewares = "middlewares"
ListMiddlewareTraces = "middleware_trace"
ListMatchDomains = "match_domains"
ListHomepageConfig = "homepage_config"
ListTasks = "tasks"
)
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
func List(w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what")
if what == "" {
what = ListRoutes
}
which := r.PathValue("which")
switch what {
case ListRoute:
if route := listRoute(which); route == nil {
http.Error(w, "not found", http.StatusNotFound)
return
} else {
U.RespondJSON(w, r, route)
}
case ListRoutes:
listRoutes(cfg, w, r)
U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type"))))
case ListConfigFiles:
listConfigFiles(w, r)
case ListMiddlewares:
listMiddlewares(w, r)
case ListMiddlewareTrace:
listMiddlewareTrace(w, r)
U.RespondJSON(w, r, middleware.All())
case ListMiddlewareTraces:
U.RespondJSON(w, r, middleware.GetAllTrace())
case ListMatchDomains:
listMatchDomains(cfg, w, r)
U.RespondJSON(w, r, config.Value().MatchDomains)
case ListHomepageConfig:
listHomepageConfig(cfg, w, r)
U.RespondJSON(w, r, config.HomepageConfig())
case ListTasks:
U.RespondJSON(w, r, task.DebugTaskMap())
default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
}
}
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias()
typeFilter := r.FormValue("type")
if typeFilter != "" {
for k, v := range routes {
if v["type"] != typeFilter {
delete(routes, k)
}
}
func listRoute(which string) any {
if which == "" {
which = "all"
}
U.HandleErr(w, r, U.RespondJson(w, routes))
if which == "all" {
return config.RoutesByAlias()
}
routes := config.RoutesByAlias()
route, ok := routes[which]
if !ok {
return nil
}
return route
}
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
@@ -67,21 +82,5 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
for i := range files {
files[i] = strings.TrimPrefix(files[i], common.ConfigBasePath+"/")
}
U.HandleErr(w, r, U.RespondJson(w, files))
}
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, middleware.GetAllTrace()))
}
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, middleware.All()))
}
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, cfg.Value().MatchDomains))
}
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, cfg.HomepageConfig()))
U.RespondJSON(w, r, files)
}

View File

@@ -13,57 +13,52 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
func ReloadServer() E.NestedError {
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
func ReloadServer() E.Error {
resp, err := U.Post(common.APIHTTPURL+"/v1/reload", "", nil)
if err != nil {
return E.From(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
b, err := io.ReadAll(resp.Body)
failure := E.Errorf("server reload status %v", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
return failure.Extraf("unable to read response body: %s", err)
return failure.With(err)
}
reloadErr, ok := E.FromJSON(b)
if ok {
return E.Join("reload success, but server returned error", reloadErr)
}
return failure.Extraf("unable to read response body")
reloadErr := string(body)
return failure.Withf(reloadErr)
}
return nil
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
func List[T any](what string) (_ T, outErr E.Error) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
if err != nil {
return nil, E.From(err)
outErr = E.From(err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode)
outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
return
}
var routes map[string]map[string]any
err = json.NewDecoder(resp.Body).Decode(&routes)
var res T
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return nil, E.From(err)
outErr = E.From(err)
return
}
return routes, nil
return res, nil
}
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
if err != nil {
return nil, E.From(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
}
var traces middleware.Traces
err = json.NewDecoder(resp.Body).Decode(&traces)
if err != nil {
return nil, E.From(err)
}
return traces, nil
func ListRoutes() (map[string]map[string]any, E.Error) {
return List[map[string]map[string]any](v1.ListRoutes)
}
func ListMiddlewareTraces() (middleware.Traces, E.Error) {
return List[middleware.Traces](v1.ListMiddlewareTraces)
}
func DebugListTasks() (map[string]any, E.Error) {
return List[map[string]any](v1.ListTasks)
}

View File

@@ -7,10 +7,10 @@ import (
"github.com/yusing/go-proxy/internal/config"
)
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err != nil {
U.RespondJson(w, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)
func Reload(w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil {
U.HandleErr(w, r, err)
return
}
U.WriteBody(w, []byte("OK"))
}

View File

@@ -5,33 +5,35 @@ import (
"net/http"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.RespondJson(w, getStats(cfg)))
func Stats(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats())
}
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses))
func StatsWS(w http.ResponseWriter, r *http.Request) {
var originPats []string
if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
if len(config.Value().MatchDomains) == 0 {
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
originPats = []string{"*"}
} else {
for i, domain := range cfg.Value().MatchDomains {
originPats[i] = "*." + domain
originPats = make([]string, len(config.Value().MatchDomains))
for i, domain := range config.Value().MatchDomains {
originPats[i] = "*" + domain
}
originPats = append(originPats, localAddresses...)
}
U.LogInfo(r).Msgf("websocket API request from origins: %s", originPats)
if common.IsDebug {
originPats = []string{"*"}
}
@@ -39,9 +41,10 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
OriginPatterns: originPats,
})
if err != nil {
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
return
}
/* trunk-ignore(golangci-lint/errcheck) */
defer conn.CloseNow()
ctx, cancel := context.WithCancel(context.Background())
@@ -51,17 +54,17 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
defer ticker.Stop()
for range ticker.C {
stats := getStats(cfg)
stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
U.LogError(r).Msg("failed to write JSON")
return
}
}
}
func getStats(cfg *config.Config) map[string]any {
func getStats() map[string]any {
return map[string]any{
"proxies": cfg.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
"proxies": config.Statistics(),
"uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
}
}

View File

@@ -1,37 +1,36 @@
package utils
import (
"errors"
"fmt"
"net/http"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
)
var Logger = logrus.WithField("module", "api")
// HandleErr logs the error and returns an HTTP error response to the client.
// If code is specified, it will be used as the HTTP status code; otherwise,
// http.StatusInternalServerError is used.
//
// The error is only logged but not returned to the client.
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
if origErr == nil {
return
}
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
Logger.Error(err)
LogError(r).Msg(origErr.Error())
statusCode := http.StatusInternalServerError
if len(code) > 0 {
http.Error(w, err.String(), code[0])
return
statusCode = code[0]
}
http.Error(w, err.String(), http.StatusInternalServerError)
http.Error(w, http.StatusText(statusCode), statusCode)
}
func ErrMissingKey(k string) error {
return errors.New("missing key '" + k + "' in query or request body")
return E.New("missing key '" + k + "' in query or request body")
}
func ErrInvalidKey(k string) error {
return errors.New("invalid key '" + k + "' in query or request body")
return E.New("invalid key '" + k + "' in query or request body")
}
func ErrNotFound(k, v string) error {
return fmt.Errorf("key %q with value %q not found", k, v)
return E.Errorf("key %q with value %q not found", k, v)
}

View File

@@ -8,20 +8,21 @@ import (
"github.com/yusing/go-proxy/internal/common"
)
var HTTPClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
var (
httpClient = &http.Client{
Timeout: common.ConnectionTimeout,
Transport: &http.Transport{
DisableKeepAlives: true,
ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{
Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
var Get = HTTPClient.Get
var Post = HTTPClient.Post
var Head = HTTPClient.Head
Get = httpClient.Get
Post = httpClient.Post
Head = httpClient.Head
)

View File

@@ -0,0 +1,18 @@
package utils
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).Str("module", "api").
Str("method", r.Method).
Str("path", r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }

View File

@@ -2,19 +2,42 @@ package utils
import (
"encoding/json"
"fmt"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
)
func RespondJson(w http.ResponseWriter, data any, code ...int) error {
func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil {
HandleErr(w, nil, err)
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
} else {
w.Write(j)
var j []byte
var err error
switch data := data.(type) {
case string:
j = []byte(fmt.Sprintf("%q", data))
case []byte:
j = data
default:
j, err = json.MarshalIndent(data, "", " ")
if err != nil {
logging.Panic().Err(err).Msg("failed to marshal json")
return false
}
}
return nil
_, err = w.Write(j)
if err != nil {
HandleErr(w, r, err)
return false
}
return true
}

View File

@@ -3,9 +3,10 @@ package v1
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/pkg"
)
func GetVersion(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(pkg.GetVersion()))
WriteBody(w, []byte(pkg.GetVersion()))
}

View File

@@ -8,12 +8,21 @@ import (
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/config/types"
)
type Config types.AutoCertConfig
var (
ErrMissingDomain = E.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'")
ErrUnknownProvider = E.New("unknown provider")
)
func NewConfig(cfg *types.AutoCertConfig) *Config {
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
@@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
func (cfg *Config) GetProvider() (*Provider, E.Error) {
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Addf("%s", "no domains specified")
b.Add(ErrMissingDomain)
}
if cfg.Provider == "" {
b.Addf("%s", "no provider specified")
b.Add(ErrMissingProvider)
}
if cfg.Email == "" {
b.Addf("%s", "no email specified")
b.Add(ErrMissingEmail)
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Addf("unknown provider: %q", cfg.Provider)
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
}
}
if b.HasError() {
return
return nil, b.Error()
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.HasError() {
b.Add(E.FailWith("generate private key", err))
return
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
b.Addf("generate private key: %w", err)
return nil, b.Error()
}
user := &User{
@@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
provider = &Provider{
return &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
}
return
}, nil
}

View File

@@ -1,13 +1,10 @@
package autocert
import (
"errors"
"github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/sirupsen/logrus"
)
const (
@@ -32,9 +29,3 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
}
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert")

View File

@@ -0,0 +1,5 @@
package autocert
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "autocert").Logger()

View File

@@ -1,9 +1,9 @@
package autocert
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"os"
"path"
"reflect"
@@ -14,24 +14,29 @@ import (
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Provider struct {
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
type (
Provider struct {
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
tlsCert *tls.Certificate
certExpiries CertExpiries
}
tlsCert *tls.Certificate
certExpiries CertExpiries
}
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
type ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type CertExpiries map[string]time.Time
CertExpiries map[string]time.Time
)
var ErrGetCertFailure = errors.New("get certificate failed")
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
@@ -56,25 +61,20 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() (res E.NestedError) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
func (p *Provider) ObtainCert() E.Error {
if p.cfg.Provider == ProviderLocal {
return nil
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
b.Add(E.FailWith("init autocert client", err))
return
if err := p.initClient(); err != nil {
return err
}
}
if p.user.Registration == nil {
if err := p.registerACME(); err.HasError() {
b.Add(E.FailWith("register ACME", err))
return
if err := p.registerACME(); err != nil {
return E.From(err)
}
}
@@ -83,27 +83,23 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.HasError() {
b.Add(err)
return
cert, err := client.Certificate.Obtain(req)
if err != nil {
return E.From(err)
}
if err = p.saveCert(cert); err.HasError() {
b.Add(E.FailWith("save certificate", err))
return
if err = p.saveCert(cert); err != nil {
return E.From(err)
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.HasError() {
b.Add(E.FailWith("parse obtained certificate", err))
return
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
if err != nil {
return E.From(err)
}
expiries, err := getCertExpiries(&tlsCert)
if err.HasError() {
b.Add(E.FailWith("get certificate expiry", err))
return
if err != nil {
return E.From(err)
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
@@ -111,22 +107,23 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
return nil
}
func (p *Provider) LoadCert() E.NestedError {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.HasError() {
return err
func (p *Provider) LoadCert() E.Error {
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err != nil {
return E.Errorf("load SSL certificate: %w", err)
}
expiries, err := getCertExpiries(&cert)
if err.HasError() {
return err
if err != nil {
return E.Errorf("parse SSL certificate: %w", err)
}
p.tlsCert = &cert
p.certExpiries = expiries
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
// ShouldRenewOn returns the time at which the certificate should be renewed.
func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before
@@ -135,55 +132,55 @@ func (p *Provider) ShouldRenewOn() time.Time {
panic("no certificate available")
}
func (p *Provider) ScheduleRenewal(ctx context.Context) {
func (p *Provider) ScheduleRenewal() {
if p.GetName() == ProviderLocal {
return
}
logger.Debug("started renewal scheduler")
defer logger.Debug("renewal scheduler stopped")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
go func() {
task := task.GlobalTask("cert renew scheduler")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
defer task.Finish("cert renew scheduler stopped")
for {
select {
case <-task.Context().Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err != nil {
E.LogWarn("cert renew failed", err, &logger)
}
}
}
}
}()
}
func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.FailWith("create lego client", err)
func (p *Provider) initClient() E.Error {
legoClient, err := lego.NewClient(p.legoCfg)
if err != nil {
return E.From(err)
}
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.FailWith("create lego provider", err)
generator := providersGenMap[p.cfg.Provider]
legoProvider, pErr := generator(p.cfg.Options)
if pErr != nil {
return pErr
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.FailWith("set challenge provider", err)
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
if err != nil {
return E.From(err)
}
p.client = legoClient
return nil
}
func (p *Provider) registerACME() E.NestedError {
func (p *Provider) registerACME() error {
if p.user.Registration != nil {
return nil
}
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
p.user.Registration = reg
@@ -191,26 +188,27 @@ func (p *Provider) registerACME() E.NestedError {
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
//* This should have been done in setup
//* but double check is always a good choice
func (p *Provider) saveCert(cert *certificate.Resource) error {
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
return E.FailWith("create cert directory", err)
return err
}
} else {
return E.FailWith("stat cert directory", err)
return err
}
}
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.FailWith("write key file", err)
return err
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.FailWith("write cert file", err)
return err
}
return nil
}
@@ -232,39 +230,36 @@ func (p *Provider) certState() CertState {
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}
return CertStateValid
}
func (p *Provider) renewIfNeeded() E.NestedError {
func (p *Provider) renewIfNeeded() E.Error {
if p.cfg.Provider == ProviderLocal {
return nil
}
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")
logger.Info().Msg("certs expired, renewing")
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
logger.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
}
if err := p.ObtainCert(); err.HasError() {
return E.FailWith("renew certificate", err)
}
return nil
return p.ObtainCert()
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.HasError() {
return nil, E.FailWith("parse certificate", err)
x509Cert, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
if x509Cert.IsCA {
continue
@@ -281,16 +276,13 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg()
err := U.Deserialize(opt, cfg)
if err.HasError() {
if err != nil {
return nil, err
}
p, err := E.Check(newProvider(cfg))
if err.HasError() {
return nil, err
}
return p, nil
p, pErr := newProvider(cfg)
return p, E.From(pErr)
}
}

View File

@@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, U.Deserialize(opt, cfg).Error())
ExpectNoError(t, U.Deserialize(opt, cfg))
ExpectDeepEqual(t, cfg, cfgExpected)
}

View File

@@ -1,27 +1,26 @@
package autocert
import (
"context"
"os"
E "github.com/yusing/go-proxy/internal/error"
)
func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
func (p *Provider) Setup() (err E.Error) {
if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
logger.Debug("obtaining cert due to error loading cert")
logger.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
}
go p.ScheduleRenewal(ctx)
p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry)
logger.Info().Msg("certificate expire on " + expiry.String())
break
}

View File

@@ -1,8 +1,9 @@
package autocert
import (
"github.com/go-acme/lego/v4/registration"
"crypto"
"github.com/go-acme/lego/v4/registration"
)
type User struct {
@@ -19,4 +20,4 @@ func (u *User) GetRegistration() *registration.Resource {
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}
}

View File

@@ -3,8 +3,7 @@ package common
import (
"flag"
"fmt"
"github.com/sirupsen/logrus"
"log"
)
type Args struct {
@@ -42,7 +41,7 @@ func GetArgs() Args {
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArg(args.Command); err != nil {
logrus.Fatal(err)
log.Fatalf("invalid command: %s", err)
}
return args
}
@@ -53,5 +52,5 @@ func validateArg(arg string) error {
return nil
}
}
return fmt.Errorf("invalid command: %s", arg)
return fmt.Errorf("invalid command %q", arg)
}

View File

@@ -13,43 +13,44 @@ const (
// file, folder structure
const (
DotEnvPath = ".env"
DotEnvExamplePath = ".env.example"
ConfigBasePath = "config"
ConfigFileName = "config.yml"
ConfigExampleFileName = "config.example.yml"
ConfigPath = ConfigBasePath + "/" + ConfigFileName
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
)
JWTKeyPath = ConfigBasePath + "/jwt.key"
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
const (
SchemaBasePath = "schema"
ConfigSchemaPath = SchemaBasePath + "/config.schema.json"
FileProviderSchemaPath = SchemaBasePath + "/providers.schema.json"
)
const (
ComposeFileName = "compose.yml"
ComposeExampleFileName = "compose.example.yml"
)
const (
ErrorPagesBasePath = "error_pages"
)
var (
RequiredDirectories = []string{
ConfigBasePath,
SchemaBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
)
var RequiredDirectories = []string{
ConfigBasePath,
SchemaBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
const DockerHostFromEnv = "$DOCKER_HOST"
const (
IdleTimeoutDefault = "0"
HealthCheckIntervalDefault = 5 * time.Second
HealthCheckTimeoutDefault = 5 * time.Second
WakeTimeoutDefault = "30s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)
const HeaderCheckRedirect = "X-Goproxy-Check-Redirect"

31
internal/common/crypto.go Normal file
View File

@@ -0,0 +1,31 @@
package common
import (
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"github.com/rs/zerolog/log"
)
func HashPassword(pwd string) []byte {
h := sha512.New()
h.Write([]byte(pwd))
return h.Sum(nil)
}
func generateJWTKey(size int) string {
bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil {
log.Panic().Err(err).Msg("failed to generate jwt key")
}
return base64.StdEncoding.EncodeToString(bytes)
}
func decodeJWTKey(key string) []byte {
bytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
log.Panic().Err(err).Msg("failed to decode jwt key")
}
return bytes
}

View File

@@ -2,19 +2,21 @@ package common
import (
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog/log"
)
var (
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true)
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
IsDebugSkipAuth = GetEnvBool("GOPROXY_DEBUG_SKIP_AUTH", false)
IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug
ProxyHTTPAddr,
ProxyHTTPHost,
@@ -30,6 +32,11 @@ var (
APIHTTPHost,
APIHTTPPort,
APIHTTPURL = GetAddrEnv("GOPROXY_API_ADDR", "127.0.0.1:8888", "http")
APIJWTSecret = decodeJWTKey(GetEnv("GOPROXY_API_JWT_SECRET", generateJWTKey(32)))
APIJWTTokenTTL = GetDurationEnv("GOPROXY_API_JWT_TOKEN_TTL", time.Hour)
APIUser = GetEnv("GOPROXY_API_USER", "admin")
APIPasswordHash = HashPassword(GetEnv("GOPROXY_API_PASSWORD", "password"))
)
func GetEnvBool(key string, defaultValue bool) bool {
@@ -39,7 +46,7 @@ func GetEnvBool(key string, defaultValue bool) bool {
}
b, err := strconv.ParseBool(value)
if err != nil {
log.Fatalf("Invalid boolean value: %s", value)
log.Fatal().Msgf("env %s: invalid boolean value: %s", key, value)
}
return b
}
@@ -56,7 +63,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
addr = GetEnv(key, defaultValue)
host, port, err := net.SplitHostPort(addr)
if err != nil {
logrus.Fatalf("Invalid address: %s", addr)
log.Fatal().Msgf("env %s: invalid address: %s", key, addr)
}
if host == "" {
host = "localhost"
@@ -64,3 +71,15 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
return
}
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
value, ok := os.LookupEnv(key)
if !ok || value == "" {
return defaultValue
}
d, err := time.ParseDuration(value)
if err != nil {
log.Fatal().Msgf("env %s: invalid duration value: %s", key, value)
}
return d
}

View File

@@ -1,236 +1,249 @@
package config
import (
"context"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
"gopkg.in/yaml.v3"
)
type Config struct {
value *types.Config
proxyProviders F.Map[string, *PR.Provider]
providers F.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
reloadReq chan struct{}
task task.Task
}
var instance *Config
var (
instance *Config
cfgWatcher watcher.Watcher
logger = logging.With().Str("module", "config").Logger()
reloadMu sync.Mutex
)
const configEventFlushInterval = 500 * time.Millisecond
const (
cfgRenameWarn = `Config file renamed, not reloading.
Make sure you rename it back before next time you start.`
cfgDeleteWarn = `Config file deleted, not reloading.
You may run "ls-config" to show or dump the current config.`
)
func GetInstance() *Config {
return instance
}
func Load() E.NestedError {
if instance != nil {
return nil
func newConfig() *Config {
return &Config{
value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](),
task: task.GlobalTask("config"),
}
instance = &Config{
value: types.DefaultConfig(),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
return instance.load()
}
func Validate(data []byte) E.NestedError {
func Load() (*Config, E.Error) {
if instance != nil {
return instance, nil
}
instance = newConfig()
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
return instance, instance.load()
}
func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
}
func MatchDomains() []string {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return instance.value.MatchDomains
}
func (cfg *Config) Value() types.Config {
if cfg == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return *cfg.value
func WatchChanges() {
task := task.GlobalTask("Config watcher")
eventQueue := events.NewEventQueue(
task,
configEventFlushInterval,
OnConfigChange,
func(err E.Error) {
E.LogError("config reload error", err, &logger)
},
)
eventQueue.Start(cfgWatcher.Events(task.Context()))
}
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
func OnConfigChange(flushTask task.Task, ev []events.Event) {
defer flushTask.Finish("config reload complete")
// no matter how many events during the interval
// just reload once and check the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
logger.Warn().Msg(cfgRenameWarn)
return
case events.ActionFileDeleted:
logger.Warn().Msg(cfgDeleteWarn)
return
}
if err := Reload(); err != nil {
// recovered in event queue
panic(err)
}
return cfg.autocertProvider
}
func (cfg *Config) Dispose() {
if cfg.watcherCancel != nil {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
func Reload() E.Error {
// avoid race between config change and API reload request
reloadMu.Lock()
defer reloadMu.Unlock()
newCfg := newConfig()
err := newCfg.load()
if err != nil {
return err
}
cfg.stopProviders()
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
instance.task.Finish("config changed")
instance.task.Wait()
*instance = *newCfg
instance.StartProxyProviders()
return nil
}
func (cfg *Config) Reload() (err E.NestedError) {
cfg.stopProviders()
err = cfg.load()
cfg.StartProxyProviders()
return
func Value() types.Config {
return *instance.value
}
func GetAutoCertProvider() *autocert.Provider {
return instance.autocertProvider
}
func (cfg *Config) Task() task.Task {
return cfg.task
}
func (cfg *Config) StartProxyProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
}
func (cfg *Config) WatchChanges() {
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
go func() {
for {
select {
case <-cfg.watcherCtx.Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err.HasError() {
cfg.l.Error(err)
}
}
}
}()
go func() {
eventCh, errCh := cfg.watcher.Events(cfg.watcherCtx)
for {
select {
case <-cfg.watcherCtx.Done():
return
case event := <-eventCh:
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
cfg.l.Error("config file deleted or renamed, ignoring...")
continue
} else {
cfg.reloadReq <- struct{}{}
}
case err := <-errCh:
cfg.l.Error(err)
continue
}
}
}()
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
do(a, r, p)
errs := cfg.providers.CollectErrorsParallel(
func(_ string, p *proxy.Provider) error {
subtask := cfg.task.Subtask(p.String())
return p.Start(subtask)
})
})
if err := E.Join(errs...); err != nil {
E.LogError("route provider errors", err, &logger)
}
}
func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
func (cfg *Config) load() E.Error {
const errMsg = "config load error"
cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
b.Add(E.FailWith("read config", err))
logrus.Fatal(b.Build())
data, err := os.ReadFile(common.ConfigPath)
if err != nil {
E.LogFatal(errMsg, err, &logger)
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
b.Add(E.FailWith("schema validation", err))
logrus.Fatal(b.Build())
if err := Validate(data); err != nil {
E.LogFatal(errMsg, err, &logger)
}
}
model := types.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
logrus.Fatal(b.Build())
if err := E.From(yaml.Unmarshal(data, model)); err != nil {
E.LogFatal(errMsg, err, &logger)
}
// errors are non fatal below
b.Add(cfg.initAutoCert(&model.AutoCert))
b.Add(cfg.loadProviders(&model.Providers))
errs := E.NewBuilder(errMsg)
errs.Add(cfg.initNotification(model.Providers.Notification))
errs.Add(cfg.initAutoCert(&model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers))
cfg.value = model
R.SetFindMuxDomains(model.MatchDomains)
return
for i, domain := range model.MatchDomains {
if !strings.HasPrefix(domain, ".") {
model.MatchDomains[i] = "." + domain
}
}
route.SetFindMuxDomains(model.MatchDomains)
return errs.Error()
}
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) {
func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) {
if len(notifCfgMap) == 0 {
return
}
errs := E.NewBuilder("notification providers load errors")
for name, notifCfg := range notifCfgMap {
_, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg)
errs.Add(err)
}
return errs.Error()
}
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
if cfg.autocertProvider != nil {
return
}
cfg.l.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
subtask := cfg.task.Subtask("load route providers")
defer subtask.Finish("done")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
errs := E.NewBuilder("route provider errors")
results := E.NewBuilder("loaded route providers")
lenLongestName := 0
for _, filename := range providers.Files {
p, err := PR.NewFileProvider(filename)
p, err := proxy.NewFileProvider(filename)
if err != nil {
b.Add(err.Subject(filename))
errs.Add(E.PrependSubject(filename, err))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(filename))
cfg.providers.Store(p.GetName(), p)
if len(p.GetName()) > lenLongestName {
lenLongestName = len(p.GetName())
}
}
for name, dockerHost := range providers.Docker {
p, err := PR.NewDockerProvider(name, dockerHost)
p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil {
b.Add(err.Subjectf("%s (%s)", name, dockerHost))
errs.Add(E.PrependSubject(name, err))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(dockerHost))
}
return
}
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
errors.Add(err.Subject(p))
cfg.providers.Store(p.GetName(), p)
if len(p.GetName()) > lenLongestName {
lenLongestName = len(p.GetName())
}
})
if err := errors.Build(); err.HasError() {
cfg.l.Error(err)
}
}
func (cfg *Config) stopProviders() {
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String()))
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes())
})
logger.Info().Msg(results.String())
return errs.Error()
}

View File

@@ -5,34 +5,36 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/common"
H "github.com/yusing/go-proxy/internal/homepage"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
entries := make(map[string]*types.RawEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
func DumpEntries() map[string]*entry.RawEntry {
entries := make(map[string]*entry.RawEntry)
instance.providers.RangeAll(func(_ string, p *proxy.Provider) {
p.RangeRoutes(func(alias string, r *route.Route) {
entries[alias] = r.Entry
})
})
return entries
}
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
entries := make(map[string]*PR.Provider)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
func DumpProviders() map[string]*proxy.Provider {
entries := make(map[string]*proxy.Provider)
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
entries[name] = p
})
return entries
}
func (cfg *Config) HomepageConfig() H.HomePageConfig {
func HomepageConfig() homepage.Config {
var proto, port string
domains := cfg.value.MatchDomains
cert, _ := cfg.autocertProvider.GetCert(nil)
domains := instance.value.MatchDomains
cert, _ := instance.autocertProvider.GetCert(nil)
if cert != nil {
proto = "https"
port = common.ProxyHTTPSPort
@@ -41,31 +43,25 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
port = common.ProxyHTTPPort
}
hpCfg := H.NewHomePageConfig()
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
if !r.Started() {
return
}
entry := r.Entry()
if entry.Homepage == nil {
entry.Homepage = &H.HomePageItem{
Show: r.Entry().IsExplicit || !p.IsExplicitOnly(),
}
}
item := entry.Homepage
if !item.Show && !item.IsEmpty() {
hpCfg := homepage.NewHomePageConfig()
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
en := r.Raw
item := en.Homepage
if item == nil {
item = new(homepage.Item)
item.Show = true
}
if !item.Show || r.Type() != R.RouteTypeReverseProxy {
if !item.IsEmpty() {
item.Show = true
}
if !item.Show {
return
}
if item.Name == "" {
item.Name = U.Title(
item.Name = strutils.Title(
strings.ReplaceAll(
strings.ReplaceAll(alias, "-", " "),
"_", " ",
@@ -73,16 +69,22 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
)
}
if p.GetType() == PR.ProviderTypeDocker {
switch {
case entry.IsDocker(r):
if item.Category == "" {
item.Category = "Docker"
}
item.SourceType = string(PR.ProviderTypeDocker)
} else if p.GetType() == PR.ProviderTypeFile {
item.SourceType = string(proxy.ProviderTypeDocker)
case entry.UseLoadBalance(r):
if item.Category == "" {
item.Category = "Load-balanced"
}
item.SourceType = "loadbalancer"
default:
if item.Category == "" {
item.Category = "Others"
}
item.SourceType = string(PR.ProviderTypeFile)
item.SourceType = string(proxy.ProviderTypeFile)
}
if item.URL == "" {
@@ -90,39 +92,39 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port)
}
}
item.AltURL = r.URL().String()
item.AltURL = r.TargetURL().String()
hpCfg.Add(item)
})
return hpCfg
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
if !r.Started() {
return
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
routes := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
}
for _, t := range typeFilter {
switch t {
case route.RouteTypeReverseProxy:
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
routes[alias] = r
})
case route.RouteTypeStream:
route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) {
routes[alias] = r
})
}
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
obj["started"] = r.Started()
obj["raw"] = r.Entry()
routes[alias] = obj
})
}
return routes
}
func (cfg *Config) Statistics() map[string]any {
func Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]PR.ProviderStats)
providerStats := make(map[string]proxy.ProviderStats)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
providerStats[name] = p.Statistics()
})
@@ -138,9 +140,9 @@ func (cfg *Config) Statistics() map[string]any {
}
}
func (cfg *Config) FindRoute(alias string) R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
func FindRoute(alias string) *route.Route {
return F.MapFind(instance.providers,
func(p *proxy.Provider) (*route.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}

View File

@@ -0,0 +1,13 @@
package types
type (
AutoCertConfig struct {
Email string `json:"email,omitempty" yaml:"email"`
Domains []string `json:"domains,omitempty" yaml:",flow"`
CertPath string `json:"cert_path,omitempty" yaml:"cert_path"`
KeyPath string `json:"key_path,omitempty" yaml:"key_path"`
Provider string `json:"provider,omitempty" yaml:"provider"`
Options AutocertProviderOpt `json:"options,omitempty" yaml:",flow"`
}
AutocertProviderOpt map[string]any
)

View File

@@ -0,0 +1,25 @@
package types
type (
Config struct {
Providers Providers `json:"providers" yaml:",flow"`
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
}
Providers struct {
Files []string `json:"include" yaml:"include"`
Docker map[string]string `json:"docker" yaml:"docker"`
Notification NotificationConfigMap `json:"notification" yaml:"notification"`
}
)
func DefaultConfig() *Config {
return &Config{
Providers: Providers{},
TimeoutShutdown: 3,
RedirectToHTTPS: false,
}
}

View File

@@ -0,0 +1,5 @@
package types
import "github.com/yusing/go-proxy/internal/notif"
type NotificationConfigMap map[string]notif.ProviderConfig

View File

@@ -1,65 +1,61 @@
package docker
import (
"errors"
"net/http"
"sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type Client struct {
key string
refCount *atomic.Int32
*client.Client
type (
Client = *SharedClient
SharedClient struct {
*client.Client
l logrus.FieldLogger
key string
refCount *U.RefCount
l zerolog.Logger
}
)
var (
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
)
func init() {
task.GlobalTask("close docker clients").OnFinished("", func() {
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
}
})
})
}
func ParseDockerHostname(host string) (string, E.NestedError) {
switch host {
case common.DockerHostFromEnv, "":
return "localhost", nil
}
url, err := E.Check(client.ParseHostURL(host))
if err != nil {
return "", E.Invalid("host", host).With(err)
}
return url.Hostname(), nil
func (c *SharedClient) Connected() bool {
return c != nil && c.Client != nil
}
func (c Client) DaemonHostname() string {
// DaemonHost should always return a valid host
hostname, _ := ParseDockerHostname(c.DaemonHost())
return hostname
}
func (c Client) Connected() bool {
return c.Client != nil
}
// if the client is still referenced, this is no-op
func (c *Client) Close() error {
if c.refCount.Add(-1) > 0 {
return nil
// if the client is still referenced, this is no-op.
func (c *SharedClient) Close() {
if c.Connected() {
c.refCount.Sub()
}
clientMap.Delete(c.key)
client := c.Client
c.Client = nil
c.l.Debugf("client closed")
if client != nil {
return client.Close()
}
return nil
}
// ConnectClient creates a new Docker client connection to the specified host.
@@ -72,13 +68,13 @@ func (c *Client) Close() error {
// Returns:
// - Client: the Docker client connection.
// - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.NestedError) {
func ConnectClient(host string) (Client, error) {
clientMapMu.Lock()
defer clientMapMu.Unlock()
// check if client exists
if client, ok := clientMap.Load(host); ok {
client.refCount.Add(1)
client.refCount.Add()
return client, nil
}
@@ -86,12 +82,14 @@ func ConnectClient(host string) (Client, E.NestedError) {
var opt []client.Opt
switch host {
case "":
return nil, errors.New("empty docker host")
case common.DockerHostFromEnv:
opt = clientOptEnvHost
default:
helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() {
return Client{}, E.UnexpectedError(err.Error())
helper, err := connhelper.GetConnectionHelper(host)
if err != nil {
logging.Panic().Err(err).Msg("failed to get connection helper")
}
if helper != nil {
httpClient := &http.Client{
@@ -113,41 +111,29 @@ func ConnectClient(host string) (Client, E.NestedError) {
}
}
client, err := E.Check(client.NewClientWithOpts(opt...))
if err.HasError() {
return Client{}, err
client, err := client.NewClientWithOpts(opt...)
if err != nil {
return nil, err
}
c := Client{
c := &SharedClient{
Client: client,
key: host,
refCount: &atomic.Int32{},
l: logger.WithField("docker_client", client.DaemonHost()),
refCount: U.NewRefCounter(),
l: logger.With().Str("address", client.DaemonHost()).Logger(),
}
c.refCount.Add(1)
c.l.Debugf("client connected")
c.l.Trace().Msg("client connected")
clientMap.Store(host, c)
go func() {
<-c.refCount.Zero()
clientMap.Delete(c.key)
if c.Connected() {
c.Client.Close()
c.l.Trace().Msg("client closed")
}
}()
return c, nil
}
func CloseAllClients() {
clientMap.RangeAll(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()
logger.Debug("closed all clients")
}
var (
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
logger = logrus.WithField("module", "docker")
)

View File

@@ -1,54 +0,0 @@
package docker
import (
"context"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
E "github.com/yusing/go-proxy/internal/error"
)
type ClientInfo struct {
Client Client
Containers []types.Container
}
var listOptions = container.ListOptions{
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// ),
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
dockerClient, err := ConnectClient(clientHost)
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
var containers []types.Container
if getContainer {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
}
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
}
func IsErrConnectionFailed(err error) bool {
return client.IsErrConnectionFailed(err)
}

View File

@@ -1,143 +1,144 @@
package docker
import (
"fmt"
"net/url"
"strconv"
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Container struct {
*types.Container
*ProxyProperties
}
type (
PortMapping = map[string]types.Port
Container struct {
_ U.NoCopy
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
isExplicit := c.Labels[LabelAliases] != ""
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ContainerID: c.ID,
ImageName: res.getImageName(),
PublicPortMapping: res.getPublicPortMapping(),
PrivatePortMapping: res.getPrivatePortMapping(),
NetworkMode: c.HostConfig.NetworkMode,
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit,
IsDatabase: res.isDatabase(),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
StopSignal: res.getDeleteLabel(LabelStopSignal),
Running: c.Status == "running" || c.State == "running",
DockerHost string `json:"docker_host" yaml:"-"`
ContainerName string `json:"container_name" yaml:"-"`
ContainerID string `json:"container_id" yaml:"-"`
ImageName string `json:"image_name" yaml:"-"`
Labels map[string]string `json:"labels" yaml:"-"`
PublicPortMapping PortMapping `json:"public_ports" yaml:"-"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `json:"private_ports" yaml:"-"` // privatePort:types.Port
PublicIP string `json:"public_ip" yaml:"-"`
PrivateIP string `json:"private_ip" yaml:"-"`
NetworkMode string `json:"network_mode" yaml:"-"`
Aliases []string `json:"aliases" yaml:"-"`
IsExcluded bool `json:"is_excluded" yaml:"-"`
IsExplicit bool `json:"is_explicit" yaml:"-"`
IsDatabase bool `json:"is_database" yaml:"-"`
IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"`
WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"`
StopMethod string `json:"stop_method,omitempty" yaml:"-"`
StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only
StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only
Running bool `json:"running" yaml:"-"`
}
)
var DummyContainer = new(Container)
func FromDocker(c *types.Container, dockerHost string) (res *Container) {
isExplicit := c.Labels[LabelAliases] != ""
helper := containerHelper{c}
res = &Container{
DockerHost: dockerHost,
ContainerName: helper.getName(),
ContainerID: c.ID,
ImageName: helper.getImageName(),
Labels: c.Labels,
PublicPortMapping: helper.getPublicPortMapping(),
PrivatePortMapping: helper.getPrivatePortMapping(),
NetworkMode: c.HostConfig.NetworkMode,
Aliases: helper.getAliases(),
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit,
IsDatabase: helper.isDatabase(),
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout),
StopMethod: helper.getDeleteLabel(LabelStopMethod),
StopTimeout: helper.getDeleteLabel(LabelStopTimeout),
StopSignal: helper.getDeleteLabel(LabelStopSignal),
Running: c.Status == "running" || c.State == "running",
}
res.setPrivateIP(helper)
res.setPublicIP()
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
func FromJSON(json types.ContainerJSON, dockerHost string) *Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
privPortStr, proto := k.Port(), k.Proto()
privPort, _ := strconv.ParseUint(privPortStr, 10, 16)
ports = append(ports, types.Port{
PrivatePort: uint16(privPort),
Type: proto,
})
for _, v := range bindings {
pubPort, _ := strconv.ParseUint(v.HostPort, 10, 16)
privPort, _ := strconv.ParseUint(k.Port(), 10, 16)
ports = append(ports, types.Port{
IP: v.HostIP,
PublicPort: uint16(pubPort),
PrivatePort: uint16(privPort),
Type: proto,
})
}
}
cont := FromDocker(&types.Container{
ID: json.ID,
Names: []string{json.Name},
Names: []string{strings.TrimPrefix(json.Name, "/")},
Image: json.Image,
Ports: ports,
Labels: json.Config.Labels,
State: json.State.Status,
Status: json.State.Status,
Mounts: json.Mounts,
NetworkSettings: &types.SummaryNetworkSettings{
Networks: json.NetworkSettings.Networks,
},
}, dockerHost)
cont.NetworkMode = string(json.HostConfig.NetworkMode)
return cont
}
func (c Container) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
func (c *Container) setPublicIP() {
if !c.Running {
return
}
return ""
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
if strings.HasPrefix(c.DockerHost, "unix://") {
c.PublicIP = "127.0.0.1"
return
}
url, err := url.Parse(c.DockerHost)
if err != nil {
logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
c.PublicIP = "127.0.0.1"
return
}
c.PublicIP = url.Hostname()
}
func (c Container) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c Container) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[0], "/")
return slashSep[len(slashSep)-1]
}
func (c Container) getPublicPortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
if v.PublicPort == 0 {
func (c *Container) setPrivateIP(helper containerHelper) {
if !strings.HasPrefix(c.DockerHost, "unix://") {
return
}
if helper.NetworkSettings == nil {
return
}
for _, v := range helper.NetworkSettings.Networks {
if v.IPAddress == "" {
continue
}
res[fmt.Sprint(v.PublicPort)] = v
c.PrivateIP = v.IPAddress
return
}
return res
}
func (c Container) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
res[fmt.Sprint(v.PrivatePort)] = v
}
return res
}
var databaseMPs = map[string]struct{}{
"/var/lib/postgresql/data": {},
"/var/lib/mysql": {},
"/var/lib/mongodb": {},
"/var/lib/mariadb": {},
"/var/lib/memcached": {},
"/var/lib/rabbitmq": {},
}
var databasePrivPorts = map[uint16]struct{}{
5432: {}, // postgres
3306: {}, // mysql, mariadb
6379: {}, // redis
11211: {}, // memcached
27017: {}, // mongodb
}
func (c Container) isDatabase() bool {
for _, m := range c.Container.Mounts {
if _, ok := databaseMPs[m.Destination]; ok {
return true
}
}
for _, v := range c.Ports {
if _, ok := databasePrivPorts[v.PrivatePort]; ok {
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package docker
import (
"strings"
"github.com/docker/docker/api/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type containerHelper struct {
*types.Container
}
// getDeleteLabel gets the value of a label and then deletes it from the container.
// If the label does not exist, an empty string is returned.
func (c containerHelper) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
}
return ""
}
func (c containerHelper) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return strutils.CommaSeperatedList(l)
}
return []string{c.getName()}
}
func (c containerHelper) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c containerHelper) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[0], "/")
return slashSep[len(slashSep)-1]
}
func (c containerHelper) getPublicPortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
if v.PublicPort == 0 {
continue
}
res[strutils.PortString(v.PublicPort)] = v
}
return res
}
func (c containerHelper) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
res[strutils.PortString(v.PrivatePort)] = v
}
return res
}
var databaseMPs = map[string]struct{}{
"/var/lib/postgresql/data": {},
"/var/lib/mysql": {},
"/var/lib/mongodb": {},
"/var/lib/mariadb": {},
"/var/lib/memcached": {},
"/var/lib/rabbitmq": {},
}
var databasePrivPorts = map[uint16]struct{}{
5432: {}, // postgres
3306: {}, // mysql, mariadb
6379: {}, // redis
11211: {}, // memcached
27017: {}, // mongodb
}
func (c containerHelper) isDatabase() bool {
for _, m := range c.Mounts {
if _, ok := databaseMPs[m.Destination]; ok {
return true
}
}
for _, v := range c.Ports {
if _, ok := databasePrivPorts[v.PrivatePort]; ok {
return true
}
}
return false
}

View File

@@ -3,9 +3,10 @@ package idlewatcher
import (
"bytes"
_ "embed"
"fmt"
"strings"
"text/template"
"github.com/yusing/go-proxy/internal/common"
)
type templateData struct {
@@ -18,18 +19,15 @@ type templateData struct {
var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
const headerCheckRedirect = "X-GoProxy-Check-Redirect"
func (w *watcher) makeRespBody(format string, args ...any) []byte {
msg := fmt.Sprintf(format, args...)
func (w *Watcher) makeLoadingPageBody() []byte {
msg := w.ContainerName + " is starting..."
data := new(templateData)
data.CheckRedirectHeader = headerCheckRedirect
data.CheckRedirectHeader = common.HeaderCheckRedirect
data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
data.Message = strings.ReplaceAll(data.Message, " ", "&ensp;")
data.Message = strings.ReplaceAll(msg, " ", "&ensp;")
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect)))
err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen in production
panic(err)

View File

@@ -0,0 +1,103 @@
package types
import (
"errors"
"time"
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
type (
Config struct {
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
StopMethod StopMethod `json:"stop_method,omitempty"`
StopSignal Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StopMethod string
Signal string
)
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
if cont == nil {
return nil, nil
}
if cont.IdleTimeout == "" {
return &Config{
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
errs := E.NewBuilder("invalid idlewatcher config")
idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
signal := E.Collect(errs, validateSignal, cont.StopSignal)
if errs.HasError() {
return nil, errs.Error()
}
return &Config{
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopTimeout: int(stopTimeout.Seconds()),
StopMethod: stopMethod,
StopSignal: signal,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
func validateDurationPostitive(value string) (time.Duration, error) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, err
}
if d < 0 {
return 0, errors.New("duration must be positive")
}
return d, nil
}
func validateSignal(s string) (Signal, error) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", errors.New("invalid signal " + s)
}
func validateStopMethod(s string) (StopMethod, error) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", errors.New("invalid stop method " + s)
}
}

View File

@@ -0,0 +1,14 @@
package types
import (
"net/http"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}

View File

@@ -1,101 +1,134 @@
package idlewatcher
import (
"context"
"crypto/tls"
"net/http"
"sync/atomic"
"time"
. "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker struct {
*watcher
type waker struct {
_ U.NoCopy
client *http.Client
rp *gphttp.ReverseProxy
stream net.Stream
hc health.HealthChecker
ready atomic.Bool
}
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
tr := &http.Transport{}
if w.NoTLSVerify {
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
const (
idleWakerCheckInterval = 100 * time.Millisecond
idleWakerCheckTimeout = time.Second
)
// TODO: support stream
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout
waker := &waker{
rp: rp,
stream: stream,
}
return &Waker{
watcher: w,
client: &http.Client{
Timeout: 1 * time.Second,
Transport: tr,
},
rp: rp,
watcher, err := registerWatcher(providerSubTask, entry, waker)
if err != nil {
return nil, E.Errorf("register watcher: %w", err)
}
switch {
case rp != nil:
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
case stream != nil:
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
default:
panic("both nil")
}
return watcher, nil
}
// lifetime should follow route provider.
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil)
}
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(providerSubTask, entry, nil, stream)
}
// Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask task.Task) E.Error {
routeSubTask.Finish("ignored")
w.task.OnCancel("stop route", func() {
routeSubTask.Parent().Finish(w.task.FinishCause())
})
return nil
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) {
if w.stream != nil {
w.stream.Close()
}
}
func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
w.wake(w.rp.ServeHTTP, rw, r)
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
return w.String()
}
func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Request) {
// pass through if container is ready
// String implements health.HealthMonitor.
func (w *Watcher) String() string {
return w.ContainerName
}
// Uptime implements health.HealthMonitor.
func (w *Watcher) Uptime() time.Duration {
return 0
}
// Status implements health.HealthMonitor.
func (w *Watcher) Status() health.Status {
if !w.ContainerRunning {
return health.StatusNapping
}
if w.ready.Load() {
next(rw, r)
return
return health.StatusHealthy
}
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
defer cancel()
if r.Header.Get(headerCheckRedirect) == "" {
// Send a loading response to the client
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Write(w.makeRespBody("%s waking up...", w.ContainerName))
return
}
// wake the container and reset idle timer
// also wait for another wake request
w.wakeCh <- struct{}{}
if <-w.wakeDone != nil {
http.Error(rw, "Error sending wake request", http.StatusInternalServerError)
return
}
// maybe another request came in while we were waiting for the wake
if w.ready.Load() {
next(rw, r)
return
}
for {
select {
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
wakeReq, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
w.URL.String(),
nil,
)
if err != nil {
w.l.Errorf("new request err to %s: %s", r.URL, err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
// we don't care about the response
_, err = w.client.Do(wakeReq)
if err == nil {
w.ready.Store(true)
rw.WriteHeader(http.StatusOK)
return
}
// retry until the container is ready or timeout
time.Sleep(100 * time.Millisecond)
healthy, _, err := w.hc.CheckHealth()
switch {
case err != nil:
w.ready.Store(false)
return health.StatusError
case healthy:
w.ready.Store(true)
return health.StatusHealthy
default:
return health.StatusStarting
}
}
// MarshalJSON implements health.HealthMonitor.
func (w *Watcher) MarshalJSON() ([]byte, error) {
var url net.URL
if w.hc.URL().Port() != "0" {
url = w.hc.URL()
}
return (&health.JSONRepresentation{
Name: w.Name(),
Status: w.Status(),
Config: w.hc.Config(),
URL: url,
}).MarshalJSON()
}

View File

@@ -0,0 +1,108 @@
package idlewatcher
import (
"context"
"errors"
"net/http"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/common"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// ServeHTTP implements http.Handler.
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
shouldNext := w.wakeFromHTTP(rw, r)
if !shouldNext {
return
}
select {
case <-r.Context().Done():
return
default:
w.rp.ServeHTTP(rw, r)
}
}
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
// pass through if container is already ready
if w.ready.Load() {
return true
}
if r.Body != nil {
defer r.Body.Close()
}
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeLoadingPageBody()
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil {
w.Err(err).Msg("error writing http response")
}
return false
}
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
checkCanceled := func() (canceled bool) {
select {
case <-ctx.Done():
w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled")
return true
case <-w.task.Context().Done():
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
default:
return false
}
}
if checkCanceled() {
return false
}
w.WakeTrace().Msg("signal received")
err := w.wakeIfStopped()
if err != nil {
w.WakeError(err)
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return false
}
for {
if checkCanceled() {
return false
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
if isCheckRedirect {
w.Debug().Msgf("redirecting to %s ...", w.hc.URL())
rw.WriteHeader(http.StatusOK)
return false
}
w.Debug().Msgf("passing through to %s ...", w.hc.URL())
return true
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View File

@@ -0,0 +1,90 @@
package idlewatcher
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Setup implements types.Stream.
func (w *Watcher) Addr() net.Addr {
return w.stream.Addr()
}
// Setup implements types.Stream.
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
conn, err = w.stream.Accept()
if err != nil {
return
}
if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.WakeError(wakeErr)
}
return
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
return w.stream.Handle(conn)
}
// Close implements types.Stream.
func (w *Watcher) Close() error {
return w.stream.Close()
}
func (w *Watcher) wakeFromStream() error {
w.resetIdleTimer()
// pass through if container is already ready
if w.ready.Load() {
return nil
}
w.WakeDebug().Msg("wake signal received")
wakeErr := w.wakeIfStopped()
if wakeErr != nil {
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
w.WakeError(wakeErr)
return wakeErr
}
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
for {
select {
case <-w.task.Context().Done():
cause := w.task.FinishCause()
w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
return cause
default:
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String())
return nil
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View File

@@ -2,269 +2,292 @@ package idlewatcher
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type (
watcher struct {
*P.ReverseProxyEntry
Watcher struct {
_ U.NoCopy
client D.Client
zerolog.Logger
ready atomic.Bool // whether the site is ready to accept connection
*idlewatcher.Config
*waker
client D.Client
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
wakeCh chan struct{}
wakeDone chan E.NestedError
ctx context.Context
cancel context.CancelFunc
refCount *sync.WaitGroup
l logrus.FieldLogger
ticker *time.Ticker
task task.Task
}
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() E.NestedError
StopCallback func() error
)
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
logger = logging.With().Str("module", "idle_watcher").Logger()
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
const dockerReqTimeout = 3 * time.Second
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) {
cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 {
panic("should not reach here")
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
key := entry.ContainerID
key := cfg.ContainerID
if w, ok := watcherMap[key]; ok {
w.refCount.Add(1)
w.ReverseProxyEntry = entry
if w, ok := watcherMap.Load(key); ok {
w.Config = cfg
w.waker = waker
w.resetIdleTimer()
providerSubtask.Finish("used existing watcher")
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
if err.HasError() {
return nil, failure.With(err)
client, err := D.ConnectClient(cfg.DockerHost)
if err != nil {
return nil, err
}
w := &watcher{
ReverseProxyEntry: entry,
client: client,
refCount: &sync.WaitGroup{},
wakeCh: make(chan struct{}),
wakeDone: make(chan E.NestedError),
l: logger.WithField("container", entry.ContainerName),
w := &Watcher{
Logger: logger.With().Str("name", cfg.ContainerName).Logger(),
Config: cfg,
waker: waker,
client: client,
task: providerSubtask,
ticker: time.NewTicker(cfg.IdleTimeout),
}
w.refCount.Add(1)
w.stopByMethod = w.getStopCallback()
watcherMap[key] = w
watcherMap.Store(key, w)
go func() {
newWatcherCh <- w
cause := w.watchUntilDestroy()
watcherMap.Delete(w.ContainerID)
w.ticker.Stop()
w.client.Close()
w.task.Finish(cause)
}()
return w, nil
}
func (w *watcher) Unregister() {
w.refCount.Add(-1)
// WakeDebug logs a debug message related to waking the container.
func (w *Watcher) WakeDebug() *zerolog.Event {
return w.Debug().Str("action", "wake")
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watchUntilCancel()
w.refCount.Wait() // wait for 0 ref count
w.client.Close()
delete(watcherMap, w.ContainerID)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
func (w *Watcher) WakeTrace() *zerolog.Event {
return w.Trace().Str("action", "wake")
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
func (w *Watcher) WakeError(err error) {
w.Err(err).Str("action", "wake").Msg("error")
}
func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{
func (w *Watcher) LogReason(action, reason string) {
w.Info().Str("reason", reason).Msg(action)
}
func (w *Watcher) containerStop(ctx context.Context) error {
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
Timeout: &w.StopTimeout,
})
}
func (w *watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerID)
func (w *Watcher) containerPause(ctx context.Context) error {
return w.client.ContainerPause(ctx, w.ContainerID)
}
func (w *watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal))
func (w *Watcher) containerKill(ctx context.Context) error {
return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerID)
func (w *Watcher) containerUnpause(ctx context.Context) error {
return w.client.ContainerUnpause(ctx, w.ContainerID)
}
func (w *watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{})
func (w *Watcher) containerStart(ctx context.Context) error {
return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerID)
func (w *Watcher) containerStatus() (string, error) {
if !w.client.Connected() {
return "", errors.New("docker client not connected")
}
ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
if err != nil {
return "", E.FailWith("inspect container", err)
return "", err
}
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() E.NestedError {
if w.ready.Load() || w.ContainerRunning {
func (w *Watcher) wakeIfStopped() error {
if w.ContainerRunning {
return nil
}
status, err := w.containerStatus()
if err.HasError() {
if err != nil {
return err
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout)
defer cancel()
// !Hard coded here since theres no constants from Docker API
switch status {
case "exited", "dead":
return E.From(w.containerStart())
return w.containerStart(ctx)
case "paused":
return E.From(w.containerUnpause())
return w.containerUnpause(ctx)
case "running":
return nil
default:
return E.Unexpected("container state", status)
panic("should not reach here")
}
}
func (w *watcher) getStopCallback() StopCallback {
var cb func() error
func (w *Watcher) getStopCallback() StopCallback {
var cb func(context.Context) error
switch w.StopMethod {
case PT.StopMethodPause:
case idlewatcher.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
case idlewatcher.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
case idlewatcher.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
if status != "running" {
return nil
}
return E.From(cb())
return func() error {
ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second)
defer cancel()
return cb(ctx)
}
}
func (w *watcher) watchUntilCancel() {
defer close(w.wakeCh)
func (w *Watcher) resetIdleTimer() {
w.Trace().Msg("reset idle timer")
w.ticker.Reset(w.IdleTimeout)
}
w.ctx, w.cancel = context.WithCancel(context.Background())
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID),
W.DockerFilterStart,
W.DockerFilterStop,
W.DockerFilterDie,
W.DockerFilterKill,
W.DockerFilterPause,
W.DockerFilterUnpause,
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("docker event watcher")
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
Filters: watcher.NewDockerFilter(
watcher.DockerFilterContainer,
watcher.DockerFilterContainerNameID(w.ContainerID),
watcher.DockerFilterStart,
watcher.DockerFilterStop,
watcher.DockerFilterDie,
watcher.DockerFilterKill,
watcher.DockerFilterDestroy,
watcher.DockerFilterPause,
watcher.DockerFilterUnpause,
),
})
return
}
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
// watchUntilDestroy waits for the container to be created, started, or unpaused,
// and then reset the idle timer.
//
// When the container is stopped, paused,
// or killed, the idle timer is stopped and the ContainerRunning flag is set to false.
//
// When the idle timer fires, the container is stopped according to the
// stop method.
//
// it exits only if the context is canceled, the container is destroyed,
// errors occurred on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() (returnCause error) {
dockerWatcher := watcher.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped")
for {
select {
case <-mainLoopCtx.Done():
w.cancel()
case <-w.ctx.Done():
w.l.Debug("stopped")
return
case <-w.task.Context().Done():
return w.task.FinishCause()
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))
if !err.Is(context.Canceled) {
E.LogError("idlewatcher error", err, &w.Logger)
}
return err
case e := <-dockerEventCh:
switch {
case e.Action == events.ActionContainerDestroy:
w.ContainerRunning = false
w.ready.Store(false)
w.LogReason("watcher stopped", "container destroyed")
return errors.New("container destroyed")
// create / start / unpause
case e.Action.IsContainerWake():
ticker.Reset(w.IdleTimeout)
w.l.Info(e)
default: // stop / pause / kill
ticker.Stop()
w.ContainerRunning = true
w.resetIdleTimer()
w.Info().Msg("awaken")
case e.Action.IsContainerSleep(): // stop / pause / kil
w.ContainerRunning = false
w.ready.Store(false)
w.l.Info(e)
w.ticker.Stop()
default:
w.Error().Msg("unexpected docker event: " + e.String())
}
case <-ticker.C:
w.l.Debug("idle timeout")
ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
// container name changed should also change the container id
if w.ContainerName != e.ActorName {
w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
w.ContainerName = e.ActorName
}
case <-w.wakeCh:
w.l.Debug("wake signal received")
ticker.Reset(w.IdleTimeout)
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
if w.ContainerID != e.ActorID {
w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
w.ContainerID = e.ActorID
// recreate event stream
eventTask.Finish("recreate event stream")
eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher)
}
case <-w.ticker.C:
w.ticker.Stop()
if w.ContainerRunning {
err := w.stopByMethod()
switch {
case errors.Is(err, context.Canceled):
continue
case err != nil:
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
default:
w.LogReason("container stopped", "idle timeout")
}
}
w.wakeDone <- err
}
}
}

View File

@@ -2,18 +2,28 @@ package docker
import (
"context"
"errors"
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
func Inspect(dockerHost string, containerID string) (*Container, error) {
client, err := ConnectClient(dockerHost)
defer client.Close()
if err != nil {
return nil, err
}
return client.Inspect(containerID)
}
func (c Client) Inspect(containerID string) (*Container, error) {
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return Container{}, E.From(err)
return nil, err
}
return FromJson(json, c.key), nil
return FromJSON(json, c.key), nil
}

View File

@@ -24,6 +24,11 @@ type (
NestedLabelMap map[string]U.SerializedObject
)
var (
ErrApplyToNil = E.New("label value is nil")
ErrFieldNotExist = E.New("field does not exist")
)
func (l *Label) String() string {
if l.Attribute == "" {
return l.Namespace + "." + l.Target
@@ -39,22 +44,22 @@ func (l *Label) String() string {
//
// Returns:
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
func ApplyLabel[T any](obj *T, l *Label) E.Error {
if obj == nil {
return E.Invalid("nil object", l)
return ErrApplyToNil.Subject(l.String())
}
switch nestedLabel := l.Value.(type) {
case *Label:
var field reflect.Value
objType := reflect.TypeFor[T]()
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
for i := range reflect.TypeFor[T]().NumField() {
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
field = reflect.ValueOf(obj).Elem().Field(i)
break
}
}
if !field.IsValid() {
return E.NotExist("field", l.Attribute)
return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String())
}
dst, ok := field.Interface().(NestedLabelMap)
if !ok {
@@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
} else {
field = field.Addr()
}
return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
if err != nil {
return err.Subject(l.String())
}
return nil
}
if dst == nil {
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
@@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
return nil
default:
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
if err != nil {
return err.Subject(l.String())
}
return nil
}
}
func ParseLabel(label string, value string) (*Label, E.NestedError) {
func ParseLabel(label string, value string) *Label {
parts := strings.Split(label, ".")
if len(parts) < 2 {
return &Label{
Namespace: label,
Value: value,
}, nil
}
}
l := &Label{
@@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
l.Attribute = parts[2]
default:
l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
if err.HasError() {
return nil, err
}
nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value)
l.Value = nestedLabel
}
return l, nil
return l
}

View File

@@ -8,17 +8,20 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
mName = "middleware1"
mAttr = "prop1"
v = "value1"
)
func makeLabel(ns, name, attr string) string {
return fmt.Sprintf("%s.%s.%s", ns, name, attr)
}
func TestNestedLabel(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[*Label](t, pl.Value)
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
sGot := ExpectType[*Label](t, lbl.Value)
ExpectFalse(t, sGot == nil)
ExpectEqual(t, sGot.Namespace, mName)
ExpectEqual(t, sGot.Attribute, mAttr)
@@ -28,13 +31,9 @@ func TestApplyNestedLabel(t *testing.T) {
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
@@ -42,10 +41,6 @@ func TestApplyNestedLabel(t *testing.T) {
}
func TestApplyNestedLabelExisting(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
checkAttr := "prop2"
checkV := "value2"
entry := new(struct {
@@ -55,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) {
entry.Middlewares[mName] = make(U.SerializedObject)
entry.Middlewares[mName][checkAttr] = checkV
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
@@ -71,19 +65,15 @@ func TestApplyNestedLabelExisting(t *testing.T) {
}
func TestApplyNestedLabelNoAttr(t *testing.T) {
mName := "middleware1"
v := "value1"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
err := ApplyLabel(entry, lbl)
ExpectNoError(t, err)
_, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
}

View File

@@ -0,0 +1,44 @@
package docker
import (
"context"
"errors"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)
var listOptions = container.ListOptions{
// created|restarting|running|removing|paused|exited|dead
// Filters: filters.NewArgs(
// filters.Arg("status", "created"),
// filters.Arg("status", "restarting"),
// filters.Arg("status", "running"),
// filters.Arg("status", "paused"),
// filters.Arg("status", "exited"),
// ),
All: true,
}
func ListContainers(clientHost string) ([]types.Container, error) {
dockerClient, err := ConnectClient(clientHost)
if err != nil {
return nil, err
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
defer cancel()
containers, err := dockerClient.ContainerList(ctx, listOptions)
if err != nil {
return nil, err
}
return containers, nil
}
func IsErrConnectionFailed(err error) bool {
return client.IsErrConnectionFailed(err)
}

View File

@@ -0,0 +1,7 @@
package docker
import (
"github.com/yusing/go-proxy/internal/logging"
)
var logger = logging.With().Str("module", "docker").Logger()

View File

@@ -1,25 +0,0 @@
package docker
import "github.com/docker/docker/api/types"
type PortMapping = map[string]types.Port
type ProxyProperties struct {
DockerHost string `yaml:"-" json:"docker_host"`
ContainerName string `yaml:"-" json:"container_name"`
ContainerID string `yaml:"-" json:"container_id"`
ImageName string `yaml:"-" json:"image_name"`
PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port
NetworkMode string `yaml:"-" json:"network_mode"`
Aliases []string `yaml:"-" json:"aliases"`
IsExcluded bool `yaml:"-" json:"is_excluded"`
IsExplicit bool `yaml:"-" json:"is_explicit"`
IsDatabase bool `yaml:"-" json:"is_database"`
IdleTimeout string `yaml:"-" json:"idle_timeout"`
WakeTimeout string `yaml:"-" json:"wake_timeout"`
StopMethod string `yaml:"-" json:"stop_method"`
StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only
Running bool `yaml:"-" json:"running"`
}

46
internal/error/base.go Normal file
View File

@@ -0,0 +1,46 @@
package error
import (
"errors"
"fmt"
)
// baseError is an immutable wrapper around an error.
type baseError struct {
Err error `json:"err"`
}
func (err *baseError) Unwrap() error {
return err.Err
}
func (err *baseError) Is(other error) bool {
if other, ok := other.(*baseError); ok {
return errors.Is(err.Err, other.Err)
}
return errors.Is(err.Err, other)
}
func (err baseError) Subject(subject string) Error {
err.Err = PrependSubject(subject, err.Err)
return &err
}
func (err *baseError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err baseError) With(extra error) Error {
return &nestedError{&err, []error{extra}}
}
func (err baseError) Withf(format string, args ...any) Error {
return &nestedError{&err, []error{fmt.Errorf(format, args...)}}
}
func (err *baseError) Error() string {
return err.Err.Error()
}

View File

@@ -2,69 +2,123 @@ package error
import (
"fmt"
"strings"
"sync"
)
type Builder struct {
*builder
}
type builder struct {
message string
errors []NestedError
about string
errs []error
sync.Mutex
}
func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
func NewBuilder(about string) *Builder {
return &Builder{about: about}
}
// adding nil / nil is no-op,
// you may safely pass expressions returning error to it
func (b Builder) Add(err NestedError) Builder {
if err != nil {
func (b *Builder) About() string {
if !b.HasError() {
return ""
}
return b.about
}
//go:inline
func (b *Builder) HasError() bool {
return len(b.errs) > 0
}
func (b *Builder) error() Error {
if !b.HasError() {
return nil
}
return &nestedError{Err: New(b.about), Extras: b.errs}
}
func (b *Builder) Error() Error {
if len(b.errs) == 1 {
return From(b.errs[0])
}
return b.error()
}
func (b *Builder) String() string {
err := b.error()
if err == nil {
return ""
}
return err.Error()
}
// Add adds an error to the Builder.
//
// adding nil is no-op.
func (b *Builder) Add(err error) *Builder {
if err == nil {
return b
}
b.Lock()
defer b.Unlock()
switch err := err.(type) {
case *baseError:
b.errs = append(b.errs, err.Err)
case *nestedError:
if err.Err == nil {
b.errs = append(b.errs, err.Extras...)
} else {
b.errs = append(b.errs, err)
}
default:
b.errs = append(b.errs, err)
}
return b
}
func (b *Builder) Adds(err string) *Builder {
b.Lock()
defer b.Unlock()
b.errs = append(b.errs, newError(err))
return b
}
func (b *Builder) Addf(format string, args ...any) *Builder {
if len(args) > 0 {
b.Lock()
b.errors = append(b.errors, err)
b.Unlock()
defer b.Unlock()
b.errs = append(b.errs, fmt.Errorf(format, args...))
} else {
b.Adds(format)
}
return b
}
func (b *Builder) AddFrom(other *Builder, flatten bool) *Builder {
if other == nil || !other.HasError() {
return b
}
b.Lock()
defer b.Unlock()
if flatten {
b.errs = append(b.errs, other.errs...)
} else {
b.errs = append(b.errs, other.error())
}
return b
}
func (b Builder) AddE(err error) Builder {
return b.Add(From(err))
}
func (b *Builder) AddRange(errs ...error) *Builder {
b.Lock()
defer b.Unlock()
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
// Otherwise, it returns a NestedError with the message and the errors collected.
//
// Returns:
// - NestedError: the built NestedError.
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return nil
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
return b.errors[0].Subject(b.message)
for _, err := range errs {
if err != nil {
b.errs = append(b.errs, err)
}
}
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *NestedError) {
if ptr == nil {
return
} else if *ptr == nil {
*ptr = b.Build()
} else {
(*ptr).With(b.Build())
}
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
return b
}

View File

@@ -1,53 +1,55 @@
package error
package error_test
import (
"context"
"errors"
"io"
"testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
eb := NewBuilder("foo")
ExpectTrue(t, errors.Is(eb.Error(), nil))
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
eb := NewBuilder("foo")
var err Error
for range 3 {
eb.Add(nil)
}
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
eb.AddRange(nil, nil, err)
ExpectFalse(t, eb.HasError())
ExpectTrue(t, eb.Error() == nil)
}
func TestBuilderIs(t *testing.T) {
eb := NewBuilder("foo")
eb.Add(context.Canceled)
eb.Add(io.ErrShortBuffer)
ExpectTrue(t, eb.HasError())
ExpectError(t, io.ErrShortBuffer, eb.Error())
ExpectError(t, context.Canceled, eb.Error())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
eb := NewBuilder("action failed")
eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2"))
eb.Add(New("Action 2").Withf("Inner: 3"))
got := eb.Build().String()
expected1 :=
(`error occurred:
- Action 1 failed:
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner: 3`)
expected2 :=
(`error occurred:
- Action 1 failed:
- invalid Inner: "1"
- invalid Inner: "2"
- Action 2 failed:
- invalid Inner: "3"`)
if got != expected1 && got != expected2 {
t.Errorf("expected \n%s, got \n%s", expected1, got)
}
got := eb.String()
expected := `action failed
• Action 1
• Inner: 1
Inner: 2
• Action 2
• Inner: 3`
ExpectEqual(t, got, expected)
}

View File

@@ -1,296 +1,31 @@
package error
import (
"encoding/json"
"errors"
"fmt"
"strings"
)
type Error interface {
error
type (
NestedError = *nestedError
nestedError struct {
subject string
err error
extras []nestedError
}
jsonNestedError struct {
Subject string
Err string
Extras []jsonNestedError
}
)
func From(err error) NestedError {
if IsNil(err) {
return nil
}
return &nestedError{err: err}
// Is is a wrapper for errors.Is when there is no sub-error.
//
// When there are sub-errors, they will also be checked.
Is(other error) bool
// With appends a sub-error to the error.
With(extra error) Error
// Withf is a wrapper for With(fmt.Errorf(format, args...)).
Withf(format string, args ...any) Error
// Subject prepends the given subject with a colon and space to the error message.
//
// If there is already a subject in the error message, the subject will be
// prepended to the existing subject with " > ".
//
// Subject empty string is ignored.
Subject(subject string) Error
// Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)).
Subjectf(format string, args ...any) Error
}
func FromJSON(data []byte) (NestedError, bool) {
var j jsonNestedError
if err := json.Unmarshal(data, &j); err != nil {
return nil, false
}
if j.Err == "" {
return nil, false
}
extras := make([]nestedError, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
return nil, false
}
extras[i] = *extra
}
return &nestedError{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
}, true
}
// Check is a helper function that
// convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, NestedError) {
return obj, From(err)
}
func Join(message string, err ...NestedError) NestedError {
extras := make([]nestedError, len(err))
nErr := 0
for i, e := range err {
if e == nil {
continue
}
extras[i] = *e
nErr += 1
}
if nErr == 0 {
return nil
}
return &nestedError{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne NestedError) Is(err error) bool {
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne NestedError) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne NestedError) With(s any) NestedError {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case NestedError:
return ne.withError(ss)
case error:
return ne.withError(From(ss))
case string:
msg = ss
case fmt.Stringer:
return ne.appendMsg(ss.String())
default:
return ne.appendMsg(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any) NestedError {
if ne == nil {
return ne
}
var subject string
switch ss := s.(type) {
case string:
subject = ss
case fmt.Stringer:
subject = ss.String()
default:
subject = fmt.Sprint(s)
}
if ne.subject == "" {
ne.subject = subject
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
} else {
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
return ne
}
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q")
}
if strings.Contains(format, "%w") {
panic("Subjectf format should not contain %w")
}
return ne.Subject(fmt.Sprintf(format, args...))
}
func (ne NestedError) JSONObject() jsonNestedError {
extras := make([]jsonNestedError, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return jsonNestedError{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
}
}
func (ne NestedError) JSON() []byte {
b, _ := json.MarshalIndent(ne.JSONObject(), "", " ")
return b
}
func (ne NestedError) NoError() bool {
return ne == nil
}
func (ne NestedError) HasError() bool {
return ne != nil
}
func errorf(format string, args ...any) NestedError {
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj jsonNestedError) (NestedError, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
}
return FromJSON(data)
}
func (ne NestedError) withError(err NestedError) NestedError {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne NestedError) appendMsg(msg string) NestedError {
if ne == nil {
return nil
}
ne.err = fmt.Errorf("%w %s", ne.err, msg)
return ne
}
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return
}
sb.WriteString(ne.err.Error())
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
for _, extra := range ne.extras {
sb.WriteRune('\n')
extra.writeToSB(sb, level+1, "- ")
}
}
}
func (ne NestedError) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return errors.New(sb.String())
}
res = fmt.Errorf("%s%w", sb.String(), ne.err)
sb.Reset()
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
res = fmt.Errorf("%w%s", res, sb.String())
for _, extra := range ne.extras {
res = errors.Join(res, extra.buildError(level+1, "- "))
}
} else {
res = fmt.Errorf("%w%s", res, sb.String())
}
return res
// this makes JSON marshaling work,
// as the builtin one doesn't.
type errStr string
func (err errStr) Error() string {
return string(err)
}

View File

@@ -1,109 +1,157 @@
package error_test
package error
import (
"errors"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBaseString(t *testing.T) {
ExpectEqual(t, New("error").Error(), "error")
}
func TestBaseWithSubject(t *testing.T) {
err := New("error")
withSubject := err.Subject("foo")
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
ExpectError(t, err, withSubject)
ExpectStrEqual(t, withSubject.Error(), "foo: error")
ExpectTrue(t, withSubject.Is(err))
ExpectError(t, err, withSubjectf)
ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error")
ExpectTrue(t, withSubjectf.Is(err))
}
func TestBaseWithExtra(t *testing.T) {
err := New("error")
extra := New("bar").Subject("baz")
withExtra := err.With(extra)
ExpectTrue(t, withExtra.Is(extra))
ExpectTrue(t, withExtra.Is(err))
ExpectTrue(t, errors.Is(withExtra, extra))
ExpectTrue(t, errors.Is(withExtra, err))
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
}
func TestBaseUnwrap(t *testing.T) {
err := errors.New("err")
wrapped := From(err)
ExpectError(t, err, errors.Unwrap(wrapped))
}
func TestNestedUnwrap(t *testing.T) {
err := errors.New("err")
err2 := New("err2")
wrapped := From(err).Subject("foo").With(err2.Subject("bar"))
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
ExpectTrue(t, ok)
ExpectError(t, err, wrapped)
ExpectError(t, err2, wrapped)
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
}
func TestErrorIs(t *testing.T) {
ExpectTrue(t, Failure("foo").Is(ErrFailure))
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
from := errors.New("error")
err := From(from)
ExpectError(t, from, err)
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
ExpectTrue(t, err.Is(from))
ExpectFalse(t, err.Is(New("error")))
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
ExpectTrue(t, errors.Is(err.Subject("foo"), from))
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
}
func TestErrorNestedIs(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
func TestErrorImmutability(t *testing.T) {
err := New("err")
err2 := New("err2")
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrDuplicated))
for range 3 {
// t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
err.Subject("foo")
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
err.With(Duplicated("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrDuplicated))
ExpectFalse(t, err.Is(ErrInvalid))
}
err.With(err2)
ExpectFalse(t, strings.Contains(err.Error(), "extra"))
ExpectFalse(t, err.Is(err2))
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
err = err.Subject("bar").Withf("baz")
ExpectTrue(t, err != nil)
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
}
func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
err1 := New("err1")
err2 := New("err2")
err3 := err1.With(err2)
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
err2.Subject("foo")
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
// check if err3 is affected by err2.Subject
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
}
func TestErrorNested(t *testing.T) {
inner := Failure("inner").
With("1").
With("1")
inner2 := Failure("inner2").
func TestErrorStringSimple(t *testing.T) {
errFailure := New("generic failure")
ne := errFailure.Subject("foo bar")
ExpectStrEqual(t, ne.Error(), "foo bar: generic failure")
ne = ne.Subject("baz")
ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure")
}
func TestErrorStringNested(t *testing.T) {
errFailure := New("generic failure")
inner := errFailure.Subject("inner").
Withf("1").
Withf("1")
inner2 := errFailure.Subject("inner2").
Subject("action 2").
With("2").
With("2")
inner3 := Failure("inner3").
Withf("2").
Withf("2")
inner3 := errFailure.Subject("inner3").
Subject("action 3").
With("3").
With("3")
ne := Failure("foo").
With("bar").
With("baz").
Withf("3").
Withf("3")
ne := errFailure.
Subject("foo").
Withf("bar").
Withf("baz").
With(inner).
With(inner.With(inner2.With(inner3)))
want :=
`foo failed:
- bar
- baz
- inner failed:
- 1
- 1
- inner failed:
- 1
- 1
- inner2 failed for "action 2":
- 2
- 2
- inner3 failed for "action 3":
- 3
- 3`
ExpectEqual(t, ne.String(), want)
ExpectEqual(t, ne.Error().Error(), want)
want := `foo: generic failure
• bar
baz
• inner: generic failure
• 1
1
• inner: generic failure
• 1
1
• action 2 > inner2: generic failure
• 2
2
• action 3 > inner3: generic failure
• 3
3`
ExpectStrEqual(t, ne.Error(), want)
}

View File

@@ -1,77 +0,0 @@
package error
import (
stderrors "errors"
"reflect"
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
)
const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) NestedError {
return errorf("%s %w", what, ErrFailure)
}
func FailedWhy(what string, why string) NestedError {
return Failure(what).With(why)
}
func FailWith(what string, err any) NestedError {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func Unexpected(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func UnexpectedError(err error) NestedError {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func Missing(subject any) NestedError {
return errorf("%w %v", ErrMissing, subject)
}
func Duplicated(subject, what any) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject any, value any) NestedError {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}
func TypeError(subject any, from, to reflect.Type) NestedError {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) NestedError {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}

43
internal/error/log.go Normal file
View File

@@ -0,0 +1,43 @@
package error
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func getLogger(logger ...*zerolog.Logger) *zerolog.Logger {
if len(logger) > 0 {
return logger[0]
}
return logging.GetLogger()
}
//go:inline
func LogFatal(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Fatal().Msg(err.Error())
}
//go:inline
func LogError(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Error().Msg(err.Error())
}
//go:inline
func LogWarn(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Warn().Msg(err.Error())
}
//go:inline
func LogPanic(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Panic().Msg(err.Error())
}
//go:inline
func LogInfo(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Info().Msg(err.Error())
}
//go:inline
func LogDebug(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Debug().Msg(err.Error())
}

View File

@@ -0,0 +1,120 @@
package error
import (
"errors"
"fmt"
"strings"
)
type nestedError struct {
Err error `json:"err"`
Extras []error `json:"extras"`
}
func (err nestedError) Subject(subject string) Error {
if err.Err == nil {
err.Err = newError(subject)
} else {
err.Err = PrependSubject(subject, err.Err)
}
return &err
}
func (err *nestedError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err nestedError) With(extra error) Error {
if extra != nil {
err.Extras = append(err.Extras, extra)
}
return &err
}
func (err nestedError) Withf(format string, args ...any) Error {
if len(args) > 0 {
err.Extras = append(err.Extras, fmt.Errorf(format, args...))
} else {
err.Extras = append(err.Extras, newError(format))
}
return &err
}
func (err *nestedError) Unwrap() []error {
if err.Err == nil {
if len(err.Extras) == 0 {
return nil
}
return err.Extras
}
return append([]error{err.Err}, err.Extras...)
}
func (err *nestedError) Is(other error) bool {
if errors.Is(err.Err, other) {
return true
}
for _, e := range err.Extras {
if errors.Is(e, other) {
return true
}
}
return false
}
func (err *nestedError) Error() string {
return buildError(err, 0)
}
//go:inline
func makeLine(err string, level int) string {
const bulletPrefix = "• "
const spaces = " "
if level == 0 {
return err
}
return spaces[:2*level] + bulletPrefix + err
}
func makeLines(errs []error, level int) []string {
if len(errs) == 0 {
return nil
}
lines := make([]string, 0, len(errs))
for _, err := range errs {
switch err := err.(type) {
case *nestedError:
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
default:
lines = append(lines, makeLine(err.Error(), level))
}
}
return lines
}
func buildError(err error, level int) string {
switch err := err.(type) {
case nil:
return makeLine("<nil>", level)
case *nestedError:
lines := make([]string, 0, 1+len(err.Extras))
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
return strings.Join(lines, "\n")
default:
return makeLine(err.Error(), level)
}
}

52
internal/error/subject.go Normal file
View File

@@ -0,0 +1,52 @@
package error
import (
"strings"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
)
type withSubject struct {
Subject string `json:"subject"`
Err error `json:"err"`
}
const subjectSep = " > "
func highlight(subject string) string {
return ansi.HighlightRed + subject + ansi.Reset
}
func PrependSubject(subject string, err error) error {
switch err := err.(type) {
case nil:
return nil
case *withSubject:
return err.Prepend(subject)
case Error:
return err.Subject(subject)
default:
return &withSubject{subject, err}
}
}
func (err withSubject) Prepend(subject string) *withSubject {
if subject != "" {
err.Subject = subject + subjectSep + err.Subject
}
return &err
}
func (err *withSubject) Is(other error) bool {
return err.Err == other
}
func (err *withSubject) Unwrap() error {
return err.Err
}
func (err *withSubject) Error() string {
subjects := strings.Split(err.Subject, subjectSep)
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1])
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error()
}

68
internal/error/utils.go Normal file
View File

@@ -0,0 +1,68 @@
package error
import (
"fmt"
)
func newError(message string) error {
return errStr(message)
}
func New(message string) Error {
if message == "" {
return nil
}
return &baseError{newError(message)}
}
func Errorf(format string, args ...any) Error {
return &baseError{fmt.Errorf(format, args...)}
}
func From(err error) Error {
if err == nil {
return nil
}
if err, ok := err.(Error); ok {
return err
}
return &baseError{err}
}
func Must[T any](v T, err error) T {
if err != nil {
LogPanic("must failed", err)
}
return v
}
func Join(errors ...error) Error {
n := 0
for _, err := range errors {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
errs := make([]error, 0, n)
for _, err := range errors {
if err != nil {
errs = append(errs, err)
}
}
return &nestedError{Extras: errs}
}
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
result, err := fn(arg)
eb.Add(err)
return result
}
func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T {
result, err := fn(arg1, arg2)
eb.Add(err)
return result
}

View File

@@ -1,24 +1,24 @@
package homepage
type (
HomePageConfig map[string]HomePageCategory
HomePageCategory []*HomePageItem
Config map[string]Category
Category []*Item
HomePageItem struct {
Show bool `yaml:"show" json:"show"`
Name string `yaml:"name" json:"name"`
Icon string `yaml:"icon" json:"icon"`
URL string `yaml:"url" json:"url"` // alias + domain
Category string `yaml:"category" json:"category"`
Description string `yaml:"description" json:"description"`
WidgetConfig map[string]any `yaml:",flow" json:"widget_config"`
Item struct {
Show bool `json:"show" yaml:"show"`
Name string `json:"name" yaml:"name"`
Icon string `json:"icon" yaml:"icon"`
URL string `json:"url" yaml:"url"` // alias + domain
Category string `json:"category" yaml:"category"`
Description string `json:"description" yaml:"description"`
WidgetConfig map[string]any `json:"widget_config" yaml:",flow"`
SourceType string `yaml:"-" json:"source_type"`
AltURL string `yaml:"-" json:"alt_url"` // original proxy target
SourceType string `json:"source_type" yaml:"-"`
AltURL string `json:"alt_url" yaml:"-"` // original proxy target
}
)
func (item *HomePageItem) IsEmpty() bool {
func (item *Item) IsEmpty() bool {
return item == nil || (item.Name == "" &&
item.Icon == "" &&
item.URL == "" &&
@@ -27,17 +27,17 @@ func (item *HomePageItem) IsEmpty() bool {
len(item.WidgetConfig) == 0)
}
func NewHomePageConfig() HomePageConfig {
return HomePageConfig(make(map[string]HomePageCategory))
func NewHomePageConfig() Config {
return Config(make(map[string]Category))
}
func (c *HomePageConfig) Clear() {
*c = make(HomePageConfig)
func (c *Config) Clear() {
*c = make(Config)
}
func (c HomePageConfig) Add(item *HomePageItem) {
func (c Config) Add(item *Item) {
if c[item.Category] == nil {
c[item.Category] = make(HomePageCategory, 0)
c[item.Category] = make(Category, 0)
}
c[item.Category] = append(c[item.Category], item)
}

View File

@@ -4,12 +4,11 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"time"
"log"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -21,8 +20,10 @@ type GitHubContents struct { //! keep this, may reuse in future
Size int `json:"size"`
}
const iconsCachePath = "/tmp/icons_cache.json"
const updateInterval = 1 * time.Hour
const (
iconsCachePath = "/tmp/icons_cache.json"
updateInterval = 1 * time.Hour
)
func ListAvailableIcons() ([]string, error) {
owner := "walkxcode"
@@ -30,13 +31,14 @@ func ListAvailableIcons() ([]string, error) {
ref := "main"
var lastUpdate time.Time
var icons = make([]string, 0)
icons := make([]string, 0)
info, err := os.Stat(iconsCachePath)
if err == nil {
lastUpdate = info.ModTime().Local()
}
if time.Since(lastUpdate) < updateInterval {
err := utils.LoadJson(iconsCachePath, &icons)
err := utils.LoadJSON(iconsCachePath, &icons)
if err == nil {
return icons, nil
}
@@ -51,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
icons = append(icons, content.Path)
}
}
err = utils.SaveJson(iconsCachePath, &icons, 0o644).Error()
err = utils.SaveJSON(iconsCachePath, &icons, 0o644)
if err != nil {
log.Print("error saving cache", err)
}
@@ -59,7 +61,7 @@ func ListAvailableIcons() ([]string, error) {
}
func getRepoContents(client *http.Client, owner string, repo string, ref string, path string) ([]GitHubContents, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,70 @@
package logging
import (
"os"
"strings"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
)
var logger zerolog.Logger
func init() {
var timeFmt string
var level zerolog.Level
var exclude []string
switch {
case common.IsTrace:
timeFmt = "04:05"
level = zerolog.TraceLevel
case common.IsDebug:
timeFmt = "01-02 15:04"
level = zerolog.DebugLevel
default:
timeFmt = "01-02 15:04"
level = zerolog.InfoLevel
exclude = []string{"module"}
}
prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces
prefix := strings.Repeat(" ", prefixLength)
logger = zerolog.New(
zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: timeFmt,
FieldsExclude: exclude,
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
msg := msgI.(string)
lines := strings.Split(msg, "\n")
if len(lines) == 1 {
return msg
}
for i := 1; i < len(lines); i++ {
lines[i] = prefix + lines[i]
}
return strings.Join(lines, "\n")
},
},
).Level(level).With().Timestamp().Logger()
}
func DiscardLogger() { logger = zerolog.Nop() }
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
func GetLogger() *zerolog.Logger { return &logger }
func With() zerolog.Context { return logger.With() }
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
func Info() *zerolog.Event { return logger.Info() }
func Warn() *zerolog.Event { return logger.Warn() }
func Error() *zerolog.Event { return logger.Error() }
func Err(err error) *zerolog.Event { return logger.Err(err) }
func Debug() *zerolog.Event { return logger.Debug() }
func Fatal() *zerolog.Event { return logger.Fatal() }
func Panic() *zerolog.Event { return logger.Panic() }
func Trace() *zerolog.Event { return logger.Trace() }

View File

@@ -16,13 +16,12 @@ var (
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
MaxIdleConnsPerHost: 100,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
DefaultTransportNoTLS = func() *http.Transport {
var clone = DefaultTransport.Clone()
clone := DefaultTransport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()

View File

@@ -5,7 +5,10 @@ import (
"net/http"
)
type ContentType string
type (
ContentType string
AcceptContentType []ContentType
)
func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type")
@@ -19,6 +22,18 @@ func GetContentType(h http.Header) ContentType {
return ContentType(ct)
}
func GetAccept(h http.Header) AcceptContentType {
var accepts []ContentType
for _, v := range h["Accept"] {
ct, _, err := mime.ParseMediaType(v)
if err != nil {
continue
}
accepts = append(accepts, ContentType(ct))
}
return accepts
}
func (ct ContentType) IsHTML() bool {
return ct == "text/html" || ct == "application/xhtml+xml"
}
@@ -30,3 +45,34 @@ func (ct ContentType) IsJSON() bool {
func (ct ContentType) IsPlainText() bool {
return ct == "text/plain"
}
func (act AcceptContentType) IsEmpty() bool {
return len(act) == 0
}
func (act AcceptContentType) AcceptHTML() bool {
for _, v := range act {
if v.IsHTML() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptJSON() bool {
for _, v := range act {
if v.IsJSON() || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptPlainText() bool {
for _, v := range act {
if v.IsPlainText() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}

View File

@@ -0,0 +1,41 @@
package http
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestContentTypes(t *testing.T) {
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
}
func TestAcceptContentTypes(t *testing.T) {
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
}

View File

@@ -0,0 +1,15 @@
package http
import "net/http"
type DummyResponseWriter struct{}
func (w DummyResponseWriter) Header() http.Header {
return make(http.Header)
}
func (w DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w DummyResponseWriter) WriteHeader(int) {}

View File

@@ -0,0 +1,91 @@
package loadbalancer
import (
"hash/fnv"
"net"
"net/http"
"sync"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool servers
mu sync.Mutex
}
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
}
return impl
}
func (impl *ipHash) OnAddServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
return
}
if s == nil {
impl.pool[i] = srv
return
}
}
impl.pool = append(impl.pool, srv)
}
func (impl *ipHash) OnRemoveServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
impl.pool[i] = nil
return
}
}
}
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
impl.serveHTTP(rw, r)
}
}
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
srv := impl.pool[idx]
if srv == nil || srv.Status().Bad() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
srv.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {
h := fnv.New32a()
h.Write([]byte(ip))
return h.Sum32()
}

View File

@@ -0,0 +1,53 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type leastConn struct {
*LoadBalancer
nConn F.Map[*Server, *atomic.Int64]
}
func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{
LoadBalancer: lb,
nConn: F.NewMapOf[*Server, *atomic.Int64](),
}
}
func (impl *leastConn) OnAddServer(srv *Server) {
impl.nConn.Store(srv, new(atomic.Int64))
}
func (impl *leastConn) OnRemoveServer(srv *Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {
minConn = nConn
srv = srvs[i]
}
}
minConn.Add(1)
srv.ServeHTTP(rw, r)
minConn.Add(-1)
}

View File

@@ -0,0 +1,297 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// TODO: stats of each server.
// TODO: support weighted mode.
type (
impl interface {
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv *Server)
OnRemoveServer(srv *Server)
}
Config struct {
Link string `json:"link" yaml:"link"`
Mode Mode `json:"mode" yaml:"mode"`
Weight weightType `json:"weight" yaml:"weight"`
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
}
LoadBalancer struct {
zerolog.Logger
impl
*Config
task task.Task
pool Pool
poolMu sync.Mutex
sumWeight weightType
startTime time.Time
}
weightType uint16
)
const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Logger: logger.With().Str("name", cfg.Link).Logger(),
Config: new(Config),
pool: newPool(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
lb.startTime = time.Now()
lb.task = routeSubtask
lb.task.OnFinished("loadbalancer cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v)
})
}
lb.pool.Clear()
})
return nil
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason)
}
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case Unset, RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
lb.pool.RangeAll(func(_ string, srv *Server) {
lb.impl.OnAddServer(srv)
})
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if cfg != nil {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.Link = cfg.Link
if lb.Mode == Unset && cfg.Mode != Unset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
lb.Options = cfg.Options
}
}
if lb.impl == nil {
lb.updateImpl()
}
}
func (lb *LoadBalancer) AddServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name) {
old, _ := lb.pool.Load(srv.Name)
lb.sumWeight -= old.Weight
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name, srv)
lb.sumWeight += srv.Weight
lb.rebalance()
lb.impl.OnAddServer(srv)
lb.Debug().
Str("action", "add").
Str("server", srv.Name).
Msgf("%d servers available", lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name) {
return
}
lb.pool.Delete(srv.Name)
lb.sumWeight -= srv.Weight
lb.rebalance()
lb.impl.OnRemoveServer(srv)
lb.Debug().
Str("action", "remove").
Str("server", srv.Name).
Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
return
}
}
func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == maxWeight {
return
}
if lb.pool.Size() == 0 {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(lb.pool.Size())
remainder := maxWeight % weightType(lb.pool.Size())
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightEach
lb.sumWeight += weightEach
if remainder > 0 {
s.Weight++
remainder--
}
})
return
}
// scale evenly
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightType(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight
})
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
lb.pool.Range(func(_ string, s *Server) bool {
if delta == 0 {
return false
}
if delta > 0 {
s.Weight++
lb.sumWeight++
delta--
} else {
s.Weight--
lb.sumWeight--
delta++
}
return true
})
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srvs := lb.availServers()
if len(srvs) == 0 {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
if r.Header.Get(common.HeaderCheckRedirect) != "" {
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
defer cancel()
// send dummy request to wake all servers
var dummyRW gphttp.DummyResponseWriter
for _, srv := range srvs {
// wake only if server implements Waker
_, ok := srv.handler.(idlewatcher.Waker)
if !ok {
continue
}
wakeReq := r.Clone(ctx)
srv.ServeHTTP(dummyRW, wakeReq)
}
}
lb.impl.ServeHTTP(srvs, rw, r)
}
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v *Server) {
extra[v.Name] = v.healthMon
})
return (&health.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
Uptime: lb.Uptime(),
Extra: map[string]any{
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
return lb.Link
}
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if lb.pool.Size() == 0 {
return health.StatusUnknown
}
if len(lb.availServers()) == 0 {
return health.StatusUnhealthy
}
return health.StatusHealthy
}
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() []*Server {
avail := make([]*Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv *Server) {
if srv.Status().Good() {
avail = append(avail, srv)
}
})
return avail
}

View File

@@ -0,0 +1,43 @@
package loadbalancer
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(new(Config))
for range 10 {
lb.AddServer(&Server{})
}
lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(new(Config))
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(new(Config))
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
}

View File

@@ -0,0 +1,5 @@
package loadbalancer
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "load_balancer").Logger()

View File

@@ -0,0 +1,32 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Mode string
const (
Unset Mode = ""
RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn"
IPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch strutils.ToLowerNoSnake(string(*mode)) {
case "":
return true
case string(RoundRobin):
*mode = RoundRobin
return true
case string(LeastConn):
*mode = LeastConn
return true
case string(IPHash):
*mode = IPHash
return true
}
*mode = RoundRobin
return false
}

View File

@@ -0,0 +1,22 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
)
type roundRobin struct {
index atomic.Uint32
}
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) % uint32(len(srvs))
srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}
}

View File

@@ -0,0 +1,55 @@
package loadbalancer
import (
"net/http"
"time"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
Server struct {
_ U.NoCopy
Name string
URL types.URL
Weight weightType
handler http.Handler
healthMon health.HealthMonitor
}
servers = []*Server
Pool = F.Map[string, *Server]
)
var newPool = F.NewMap[Pool]
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{
Name: name,
URL: url,
Weight: weight,
handler: handler,
healthMon: healthMon,
}
return srv
}
func (srv *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srv.handler.ServeHTTP(rw, r)
}
func (srv *Server) String() string {
return srv.Name
}
func (srv *Server) Status() health.Status {
return srv.healthMon.Status()
}
func (srv *Server) Uptime() time.Duration {
return srv.healthMon.Uptime()
}

View File

@@ -0,0 +1,5 @@
package http
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "http").Logger()

View File

@@ -5,7 +5,7 @@ import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -15,9 +15,9 @@ type cidrWhitelist struct {
}
type cidrWhitelistOpts struct {
Allow []*types.CIDR
StatusCode int
Message string
Allow []*types.CIDR `json:"allow"`
StatusCode int `json:"statusCode"`
Message string `json:"message"`
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
@@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
}
}
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist)
wl.m = &Middleware{
impl: wl,
@@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
return nil, err
}
if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.Missing("allow range")
return nil, E.New("no allowed CIDRs")
}
return wl.m, nil
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
if err != nil {
panic(err)
}
errs := E.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
@@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("deny", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
}
@@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("accept", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error())
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})

View File

@@ -10,10 +10,10 @@ import (
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
const (
@@ -26,14 +26,14 @@ const (
var (
cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &realIP{
m: &Middleware{withOptions: NewCloudflareRealIP},
}
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
@@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
}
cfCIDRsLastUpdate = time.Now()
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
return
}
@@ -109,9 +109,9 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
_, cidr, err := net.ParseCIDR(line)
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
}
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
}
return nil

View File

@@ -2,43 +2,52 @@ package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/error_page"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
var CustomErrorPage = &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
modifyResponse: func(resp *Response) error {
var CustomErrorPage *Middleware
func init() {
CustomErrorPage = customErrorPage()
}
func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
},
}
return m
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
@@ -48,27 +57,27 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
}
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
filename := path[len(gphttp.StaticFilePathPrefix):]
file, ok := error_page.GetStaticFile(filename)
file, ok := errorpage.GetStaticFile(filename)
if !ok {
errPageLogger.Errorf("unable to load resource %s", filename)
logger.Error().Msg("unable to load resource " + filename)
return false
} else {
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
}
w.Write(file)
return true
}
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set("Content-Type", "text/html; charset=utf-8")
case ".js":
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8")
default:
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
logger.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}
var errPageLogger = logrus.WithField("middleware", "error_page")

View File

@@ -1,14 +1,15 @@
package error_page
package errorpage
import (
"context"
"fmt"
"os"
"path"
"sync"
api "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
@@ -17,19 +18,34 @@ import (
const errPagesBasePath = common.ErrorPagesBasePath
var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
var (
setupMu sync.Mutex
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
func setup() {
setupMu.Lock()
defer setupMu.Unlock()
if dirWatcher != nil {
return
}
task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
loadContent()
go watchDir()
})
go watchDir(task)
}
func GetStaticFile(filename string) ([]byte, bool) {
setup()
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
@@ -39,7 +55,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
api.Logger.Error(err)
logger.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
@@ -48,19 +64,21 @@ func loadContent() {
}
content, err := os.ReadFile(file)
if err != nil {
api.Logger.Errorf("failed to read error page resource %s: %s", file, err)
logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
api.Logger.Infof("error page resource %s loaded", file)
logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(context.Background())
func watchDir(task task.Task) {
eventCh, errCh := dirWatcher.Events(task.Context())
for {
select {
case <-task.Context().Done():
return
case event, ok := <-eventCh:
if !ok {
return
@@ -72,17 +90,14 @@ func watchDir() {
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
api.Logger.Infof("error page resource %s deleted", filename)
logger.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
api.Logger.Infof("error page resource %s deleted", filename)
logger.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
api.Logger.Errorf("error watching error page directory: %s", err)
E.LogError("error watching error page directory", err, &logger)
}
}
}
var dirWatcher W.Watcher
var fileContentMap = F.NewMapOf[string, []byte]()

View File

@@ -0,0 +1,5 @@
package errorpage
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "errorpage").Logger()

View File

@@ -24,11 +24,12 @@ type (
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
transport http.RoundTripper
Address string `json:"address"`
TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
}
)
@@ -36,16 +37,14 @@ var ForwardAuth = &forwardAuth{
m: &Middleware{withOptions: NewForwardAuthfunc},
}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts)
if err != nil {
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
return nil, err
}
_, err = E.Check(url.Parse(fa.Address))
if err != nil {
return nil, E.Invalid("address", fa.Address)
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{

View File

@@ -0,0 +1,5 @@
package middleware
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "middleware").Logger()

View File

@@ -5,13 +5,14 @@ import (
"errors"
"net/http"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.NestedError
Error = E.Error
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
@@ -24,12 +25,16 @@ type (
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
Options any
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
@@ -75,48 +80,57 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
} else {
return mWithOpt, nil
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
m.name,
m.before,
m.modifyResponse,
nil,
m.impl,
m.parent,
m.children,
false,
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
}
return nil
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err := m.WithOptionsClone(opts)
m, err = m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@@ -124,10 +138,18 @@ func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[strin
middlewares = append(middlewares, m)
}
if invalidM.HasError() {
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
return middlewares, errs.Error()
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rpName, rp, middlewares)
return
}

View File

@@ -4,67 +4,63 @@ import (
"fmt"
"net/http"
"os"
"path"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
eb.Add(err)
return nil
}
return BuildMiddlewaresFromYAML(fileContent)
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("yaml unmarshal", err))
return
eb.Add(err)
return nil
}
middlewares = make(map[string]*Middleware)
middlewares := make(map[string]*Middleware)
for name, defs := range rawMap {
chainErr := E.NewBuilder(name)
chainErr := E.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
chainErr.Addf("item %d: missing field 'use'", i)
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
base, err := Get(baseName)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {
b.Add(chainErr.Build())
eb.Add(chainErr.Error().Subject(source))
} else {
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
return middlewares
}
// TODO: check conflict or duplicates
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
m := &Middleware{name: name, children: chain}
@@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware")
errs := E.NewBuilder("modify response errors")
for _, mr := range modResps {
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
}
return b.Build().Error()
return errs.Error()
}
}

View File

@@ -13,10 +13,10 @@ import (
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
ExpectNoError(t, err.Error())
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
errs := E.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
E.Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data))
// TODO: test
}

View File

@@ -6,30 +6,39 @@ import (
"path"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var middlewares map[string]*Middleware
var allMiddlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[U.ToLowerNoSnake(name)]
return
var (
ErrUnknownMiddleware = E.New("unknown middleware")
ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return middlewares
return allMiddlewares
}
// initialize middleware names and label parsers
// initialize middleware names and label parsers.
func init() {
middlewares = map[string]*Middleware{
allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
@@ -37,9 +46,13 @@ func init() {
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
// !experimental
"forwardauth": ForwardAuth.m,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name))
}
for m, names := range names {
@@ -52,27 +65,30 @@ func init() {
}
func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err)
logger.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromComposeFile(defFile)
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
middlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
allMiddlewares[strutils.ToLowerNoSnake(name)] = m
logger.Info().
Str("name", name).
Str("src", path.Base(defFile)).
Msg("middleware loaded")
}
b.Add(err.Subject(path.Base(defFile)))
}
if b.HasError() {
logger.Error(b.Build())
if errs.HasError() {
E.LogError(errs.About(), errs.Error(), &logger)
}
}
var logger = logrus.WithField("module", "middlewares")

View File

@@ -12,9 +12,9 @@ type (
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
SetHeaders map[string]string `json:"setHeaders"`
AddHeaders map[string]string `json:"addHeaders"`
HideHeaders []string `json:"hideHeaders"`
}
)
@@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{
m: &Middleware{withOptions: NewModifyRequest},
}
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
var mrFunc RewriteFunc
if common.IsDebug {

Some files were not shown because too many files have changed in this diff Show More