mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bf520541b | ||
|
|
a531896bd6 | ||
|
|
e005b42d18 | ||
|
|
1f6573b6da | ||
|
|
73af381c4c | ||
|
|
625bf4dfdc | ||
|
|
46b4090629 | ||
|
|
91e012987e | ||
|
|
a86d316d07 | ||
|
|
76454df5e6 | ||
|
|
67b6e40f85 | ||
|
|
9889b5a8d3 | ||
|
|
00fc75b61b | ||
|
|
4ee93a1351 | ||
|
|
669d13b89a | ||
|
|
5fa86b5eb7 | ||
|
|
369cdf8c4f | ||
|
|
0397f69853 | ||
|
|
81177926ff | ||
|
|
e5bbb18414 | ||
|
|
cfa74d69ae | ||
|
|
bee26f43d4 | ||
|
|
a3ab32e9ab | ||
|
|
c847fe4747 | ||
|
|
a278711421 | ||
|
|
01ffe0d97c | ||
|
|
bd732dfa0a | ||
|
|
8b8e1773e8 | ||
|
|
b296fb2965 | ||
|
|
53557e38b6 | ||
|
|
c0c61709ca | ||
|
|
56b778f19c | ||
|
|
f4d532598c | ||
|
|
53fa28ae77 | ||
|
|
f38b3abdbc | ||
|
|
99207ae606 | ||
|
|
d3b8cb8cba | ||
|
|
51c6eb4597 | ||
|
|
d47b672aa5 | ||
|
|
64e30f59e8 | ||
|
|
cef7b3d396 | ||
|
|
7184c9cfe9 | ||
|
|
da04a0dff4 | ||
|
|
d91b66ae87 | ||
|
|
5c40f4aa84 | ||
|
|
1797896fa6 | ||
|
|
d1c9e18c97 | ||
|
|
ef83ed0596 | ||
|
|
d89155a6ee |
22
.env.example
Normal file
22
.env.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# set timezone to get correct log timestamp
|
||||
TZ=ETC/UTC
|
||||
|
||||
# generate secret with `openssl rand -base64 32`
|
||||
GOPROXY_API_JWT_SECRET=
|
||||
|
||||
# the JWT token time-to-live
|
||||
GOPROXY_API_JWT_TOKEN_TTL=1h
|
||||
|
||||
# API/WebUI login credentials
|
||||
GOPROXY_API_USER=admin
|
||||
GOPROXY_API_PASSWORD=password
|
||||
|
||||
# Proxy listening address
|
||||
GOPROXY_HTTP_ADDR=:80
|
||||
GOPROXY_HTTPS_ADDR=:443
|
||||
|
||||
# API listening address
|
||||
GOPROXY_API_ADDR=127.0.0.1:8888
|
||||
|
||||
# Debug mode
|
||||
GOPROXY_DEBUG=false
|
||||
4
.github/workflows/docker-image.yml
vendored
4
.github/workflows/docker-image.yml
vendored
@@ -11,7 +11,7 @@ env:
|
||||
jobs:
|
||||
build:
|
||||
name: Build multi-platform Docker image
|
||||
runs-on: self-hosted
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
merge:
|
||||
runs-on: self-hosted
|
||||
runs-on: ubuntu-22.04
|
||||
needs:
|
||||
- build
|
||||
permissions:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,8 @@
|
||||
compose.yml
|
||||
*.compose.yml
|
||||
|
||||
config
|
||||
certs
|
||||
config*/
|
||||
certs*/
|
||||
bin/
|
||||
@@ -22,3 +24,4 @@ todo.md
|
||||
.aider*
|
||||
mtrace.json
|
||||
.env
|
||||
test.Dockerfile
|
||||
|
||||
137
.golangci.yml
Normal file
137
.golangci.yml
Normal file
@@ -0,0 +1,137 @@
|
||||
run:
|
||||
timeout: 10m
|
||||
|
||||
linters-settings:
|
||||
govet:
|
||||
enable-all: true
|
||||
disable:
|
||||
- shadow
|
||||
- fieldalignment
|
||||
gocyclo:
|
||||
min-complexity: 14
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 4
|
||||
misspell:
|
||||
locale: US
|
||||
funlen:
|
||||
lines: -1
|
||||
statements: 120
|
||||
forbidigo:
|
||||
forbid:
|
||||
- ^print(ln)?$
|
||||
godox:
|
||||
keywords:
|
||||
- FIXME
|
||||
tagalign:
|
||||
align: false
|
||||
sort: true
|
||||
order:
|
||||
- description
|
||||
- json
|
||||
- toml
|
||||
- yaml
|
||||
- yml
|
||||
- label
|
||||
- label-slice-as-struct
|
||||
- file
|
||||
- kv
|
||||
- export
|
||||
stylecheck:
|
||||
dot-import-whitelist:
|
||||
- github.com/yusing/go-proxy/internal/utils/testing # go tests only
|
||||
- github.com/yusing/go-proxy/internal/api/v1/utils # api only
|
||||
revive:
|
||||
rules:
|
||||
- name: struct-tag
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: context-keys-type
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
disabled: true
|
||||
- name: if-return
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
disabled: true
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
gomoddirectives:
|
||||
replace-allow-list:
|
||||
- github.com/abbot/go-http-auth
|
||||
- github.com/gorilla/mux
|
||||
- github.com/mailgun/minheap
|
||||
- github.com/mailgun/multibuf
|
||||
- github.com/jaguilar/vt100
|
||||
- github.com/cucumber/godog
|
||||
- github.com/http-wasm/http-wasm-host-go
|
||||
testifylint:
|
||||
disable:
|
||||
- suite-dont-use-pkg
|
||||
- require-error
|
||||
- go-require
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -SA1019
|
||||
errcheck:
|
||||
exclude-functions:
|
||||
- fmt.Fprintln
|
||||
linters:
|
||||
enable-all: true
|
||||
disable:
|
||||
- execinquery # deprecated
|
||||
- gomnd # deprecated
|
||||
- sqlclosecheck # not relevant (SQL)
|
||||
- rowserrcheck # not relevant (SQL)
|
||||
- cyclop # duplicate of gocyclo
|
||||
- depguard # Not relevant
|
||||
- nakedret # Too strict
|
||||
- lll # Not relevant
|
||||
- gocyclo # FIXME must be fixed
|
||||
- gocognit # Too strict
|
||||
- nestif # Too many false-positive.
|
||||
- prealloc # Too many false-positive.
|
||||
- makezero # Not relevant
|
||||
- dupl # Too strict
|
||||
- gci # I don't care
|
||||
- gosec # Too strict
|
||||
- gochecknoinits
|
||||
- gochecknoglobals
|
||||
- wsl # Too strict
|
||||
- nlreturn # Not relevant
|
||||
- mnd # Too strict
|
||||
- testpackage # Too strict
|
||||
- tparallel # Not relevant
|
||||
- paralleltest # Not relevant
|
||||
- exhaustive # Not relevant
|
||||
- exhaustruct # Not relevant
|
||||
- err113 # Too strict
|
||||
- wrapcheck # Too strict
|
||||
- noctx # Too strict
|
||||
- bodyclose # too many false-positive
|
||||
- forcetypeassert # Too strict
|
||||
- tagliatelle # Too strict
|
||||
- varnamelen # Not relevant
|
||||
- nilnil # Not relevant
|
||||
- ireturn # Not relevant
|
||||
- contextcheck # too many false-positive
|
||||
- containedctx # too many false-positive
|
||||
- maintidx # kind of duplicate of gocyclo
|
||||
- nonamedreturns # Too strict
|
||||
- gosmopolitan # not relevant
|
||||
- exportloopref # Not relevant since go1.22
|
||||
9
.trunk/.gitignore
vendored
Normal file
9
.trunk/.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
*out
|
||||
*logs
|
||||
*actions
|
||||
*notifications
|
||||
*tools
|
||||
plugins
|
||||
user_trunk.yaml
|
||||
user.yaml
|
||||
tmp
|
||||
41
.trunk/trunk.yaml
Normal file
41
.trunk/trunk.yaml
Normal file
@@ -0,0 +1,41 @@
|
||||
# This file controls the behavior of Trunk: https://docs.trunk.io/cli
|
||||
# To learn more about the format of this file, see https://docs.trunk.io/reference/trunk-yaml
|
||||
version: 0.1
|
||||
cli:
|
||||
version: 1.22.6
|
||||
# Trunk provides extensibility via plugins. (https://docs.trunk.io/plugins)
|
||||
plugins:
|
||||
sources:
|
||||
- id: trunk
|
||||
ref: v1.6.3
|
||||
uri: https://github.com/trunk-io/plugins
|
||||
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
|
||||
runtimes:
|
||||
enabled:
|
||||
- node@18.12.1
|
||||
- python@3.10.8
|
||||
- go@1.23.2
|
||||
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
|
||||
lint:
|
||||
enabled:
|
||||
- hadolint@2.12.0
|
||||
- actionlint@1.7.3
|
||||
- checkov@3.2.257
|
||||
- git-diff-check
|
||||
- gofmt@1.20.4
|
||||
- golangci-lint@1.61.0
|
||||
- markdownlint@0.42.0
|
||||
- osv-scanner@1.9.0
|
||||
- oxipng@9.1.2
|
||||
- prettier@3.3.3
|
||||
- shellcheck@0.10.0
|
||||
- shfmt@3.6.0
|
||||
- trufflehog@3.82.7
|
||||
- yamllint@1.35.1
|
||||
actions:
|
||||
disabled:
|
||||
- trunk-announce
|
||||
- trunk-check-pre-push
|
||||
- trunk-fmt-pre-commit
|
||||
enabled:
|
||||
- trunk-upgrade-available
|
||||
21
Makefile
21
Makefile
@@ -28,14 +28,20 @@ get:
|
||||
go get -u ./cmd && go mod tidy
|
||||
|
||||
debug:
|
||||
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
make build
|
||||
sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
|
||||
debug-trace:
|
||||
make build
|
||||
sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
|
||||
|
||||
profile:
|
||||
GODEBUG=gctrace=1 make build
|
||||
sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
|
||||
mtrace:
|
||||
bin/go-proxy debug-ls-mtrace > mtrace.json
|
||||
|
||||
run-test:
|
||||
make build && sudo GOPROXY_TEST=1 bin/go-proxy
|
||||
|
||||
run:
|
||||
make build && sudo bin/go-proxy
|
||||
|
||||
@@ -49,7 +55,7 @@ repush:
|
||||
git push gitlab dev --force
|
||||
|
||||
rapid-crash:
|
||||
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\
|
||||
sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\
|
||||
sleep 3 &&\
|
||||
sudo docker rm -f test_crash
|
||||
|
||||
@@ -58,4 +64,7 @@ debug-list-containers:
|
||||
|
||||
ci-test:
|
||||
mkdir -p /tmp/artifacts
|
||||
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
|
||||
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
|
||||
|
||||
cloc:
|
||||
cloc --not-match-f '_test.go$$' cmd internal pkg
|
||||
43
README.md
43
README.md
@@ -24,6 +24,8 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
- [Key Features](#key-features)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Setup](#setup)
|
||||
- [Manual Setup](#manual-setup)
|
||||
- [Folder structrue](#folder-structrue)
|
||||
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
|
||||
- [Screenshots](#screenshots)
|
||||
- [idlesleeper](#idlesleeper)
|
||||
@@ -59,10 +61,12 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
docker pull ghcr.io/yusing/go-proxy:latest
|
||||
```
|
||||
|
||||
2. Create new directory, `cd` into it, then run setup
|
||||
2. Create new directory, `cd` into it, then run setup, or [set up manually](#manual-setup)
|
||||
|
||||
```shell
|
||||
docker run --rm -v .:/setup ghcr.io/yusing/go-proxy /app/go-proxy setup
|
||||
# Then set the JWT secret
|
||||
sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env
|
||||
```
|
||||
|
||||
3. Setup DNS Records point to machine which runs `go-proxy`, e.g.
|
||||
@@ -83,6 +87,43 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
|
||||
[🔼Back to top](#table-of-content)
|
||||
|
||||
### Manual Setup
|
||||
|
||||
1. Make `config` directory then grab `config.example.yml` into `config/config.yml`
|
||||
|
||||
`mkdir -p config && wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/config.example.yml -O config/config.yml`
|
||||
|
||||
2. Grab `.env.example` into `.env`
|
||||
|
||||
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/.env.example -O .env`
|
||||
|
||||
3. Grab `compose.example.yml` into `compose.yml`
|
||||
|
||||
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/compose.example.yml -O compose.yml`
|
||||
|
||||
4. Set the JWT secret
|
||||
|
||||
`sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env`
|
||||
|
||||
5. Start the container `docker compose up -d`
|
||||
|
||||
### Folder structrue
|
||||
|
||||
```shell
|
||||
├── certs
|
||||
│ ├── cert.crt
|
||||
│ └── priv.key
|
||||
├── compose.yml
|
||||
├── config
|
||||
│ ├── config.yml
|
||||
│ ├── middlewares
|
||||
│ │ ├── middleware1.yml
|
||||
│ │ ├── middleware2.yml
|
||||
│ ├── provider1.yml
|
||||
│ └── provider2.yml
|
||||
└── .env
|
||||
```
|
||||
|
||||
### Use JSON Schema in VSCode
|
||||
|
||||
Copy [`.vscode/settings.example.json`](.vscode/settings.example.json) to `.vscode/settings.json` and modify it to fit your needs
|
||||
|
||||
183
cmd/main.go
183
cmd/main.go
@@ -1,76 +1,79 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal"
|
||||
"github.com/yusing/go-proxy/internal/api"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/server"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
)
|
||||
|
||||
func main() {
|
||||
args := common.GetArgs()
|
||||
|
||||
if args.Command == common.CommandSetup {
|
||||
switch args.Command {
|
||||
case common.CommandSetup:
|
||||
internal.Setup()
|
||||
return
|
||||
}
|
||||
|
||||
l := logrus.WithField("module", "main")
|
||||
onShutdown := F.NewSlice[func()]()
|
||||
|
||||
if common.IsDebug {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
if args.Command != common.CommandStart {
|
||||
logrus.SetOutput(io.Discard)
|
||||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
DisableSorting: true,
|
||||
FullTimestamp: true,
|
||||
ForceColors: true,
|
||||
TimestampFormat: "01-02 15:04:05",
|
||||
})
|
||||
logrus.Infof("go-proxy version %s", pkg.GetVersion())
|
||||
}
|
||||
|
||||
if args.Command == common.CommandReload {
|
||||
case common.CommandReload:
|
||||
if err := query.ReloadServer(); err != nil {
|
||||
E.LogFatal("server reload error", err)
|
||||
}
|
||||
logging.Info().Msg("ok")
|
||||
return
|
||||
case common.CommandListIcons:
|
||||
icons, err := internal.ListAvailableIcons()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Print("ok")
|
||||
printJSON(icons)
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := query.ListRoutes()
|
||||
if err != nil {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
printJSON(config.RoutesByAlias())
|
||||
} else {
|
||||
printJSON(routes)
|
||||
}
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
return
|
||||
}
|
||||
|
||||
// exit if only validate config
|
||||
if args.Command == common.CommandStart {
|
||||
logging.Info().Msgf("go-proxy version %s", pkg.GetVersion())
|
||||
logging.Trace().Msg("trace enabled")
|
||||
// logging.AddHook(notif.GetDispatcher())
|
||||
} else {
|
||||
logging.DiscardLogger()
|
||||
}
|
||||
|
||||
if args.Command == common.CommandValidate {
|
||||
data, err := os.ReadFile(common.ConfigPath)
|
||||
if err == nil {
|
||||
err = config.Validate(data).Error()
|
||||
err = config.Validate(data)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal("config error: ", err)
|
||||
@@ -85,68 +88,39 @@ func main() {
|
||||
|
||||
middleware.LoadComposeFiles()
|
||||
|
||||
if err := config.Load(); err != nil {
|
||||
logrus.Warn(err)
|
||||
var cfg *config.Config
|
||||
var err E.Error
|
||||
if cfg, err = config.Load(); err != nil {
|
||||
E.LogWarn("errors in config", err)
|
||||
}
|
||||
cfg := config.GetInstance()
|
||||
|
||||
switch args.Command {
|
||||
case common.CommandListConfigs:
|
||||
printJSON(cfg.Value())
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := query.ListRoutes()
|
||||
if err != nil {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
printJSON(cfg.RoutesByAlias())
|
||||
} else {
|
||||
printJSON(routes)
|
||||
}
|
||||
return
|
||||
case common.CommandListIcons:
|
||||
icons, err := internal.ListAvailableIcons()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(icons)
|
||||
printJSON(config.Value())
|
||||
return
|
||||
case common.CommandDebugListEntries:
|
||||
printJSON(cfg.DumpEntries())
|
||||
printJSON(config.DumpEntries())
|
||||
return
|
||||
case common.CommandDebugListProviders:
|
||||
printJSON(cfg.DumpProviders())
|
||||
printJSON(config.DumpProviders())
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
}
|
||||
|
||||
cfg.StartProxyProviders()
|
||||
cfg.WatchChanges()
|
||||
|
||||
onShutdown.Add(docker.CloseAllClients)
|
||||
onShutdown.Add(cfg.Dispose)
|
||||
config.WatchChanges()
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT)
|
||||
signal.Notify(sig, syscall.SIGTERM)
|
||||
signal.Notify(sig, syscall.SIGHUP)
|
||||
|
||||
autocert := cfg.GetAutoCertProvider()
|
||||
|
||||
autocert := config.GetAutoCertProvider()
|
||||
if autocert != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
if err := autocert.Setup(ctx); err != nil {
|
||||
l.Fatal(err)
|
||||
} else {
|
||||
onShutdown.Add(cancel)
|
||||
if err := autocert.Setup(); err != nil {
|
||||
E.LogFatal("autocert setup error", err)
|
||||
}
|
||||
} else {
|
||||
l.Info("autocert not configured")
|
||||
logging.Info().Msg("autocert not configured")
|
||||
}
|
||||
|
||||
proxyServer := server.InitProxyServer(server.Options{
|
||||
@@ -155,75 +129,40 @@ func main() {
|
||||
HTTPAddr: common.ProxyHTTPAddr,
|
||||
HTTPSAddr: common.ProxyHTTPSAddr,
|
||||
Handler: http.HandlerFunc(R.ProxyHandler),
|
||||
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
|
||||
RedirectToHTTPS: config.Value().RedirectToHTTPS,
|
||||
})
|
||||
apiServer := server.InitAPIServer(server.Options{
|
||||
Name: "api",
|
||||
CertProvider: autocert,
|
||||
HTTPAddr: common.APIHTTPAddr,
|
||||
Handler: api.NewHandler(cfg),
|
||||
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
|
||||
Handler: api.NewHandler(),
|
||||
RedirectToHTTPS: config.Value().RedirectToHTTPS,
|
||||
})
|
||||
|
||||
proxyServer.Start()
|
||||
apiServer.Start()
|
||||
onShutdown.Add(proxyServer.Stop)
|
||||
onShutdown.Add(apiServer.Stop)
|
||||
|
||||
go idlewatcher.Start()
|
||||
onShutdown.Add(idlewatcher.Stop)
|
||||
|
||||
// wait for signal
|
||||
<-sig
|
||||
|
||||
// grafully shutdown
|
||||
logrus.Info("shutting down")
|
||||
done := make(chan struct{}, 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(onShutdown.Size())
|
||||
onShutdown.ForEach(func(f func()) {
|
||||
go func() {
|
||||
l.Debugf("waiting for %s to complete...", funcName(f))
|
||||
f()
|
||||
l.Debugf("%s done", funcName(f))
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timeout := time.After(time.Duration(cfg.Value().TimeoutShutdown) * time.Second)
|
||||
select {
|
||||
case <-done:
|
||||
logrus.Info("shutdown complete")
|
||||
case <-timeout:
|
||||
logrus.Info("timeout waiting for shutdown")
|
||||
onShutdown.ForEach(func(f func()) {
|
||||
l.Warnf("%s() is still running", funcName(f))
|
||||
})
|
||||
}
|
||||
logging.Info().Msg("shutting down")
|
||||
task.CancelGlobalContext()
|
||||
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
|
||||
}
|
||||
|
||||
func prepareDirectory(dir string) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(dir, 0755); err != nil {
|
||||
logrus.Fatalf("failed to create directory %s: %v", dir, err)
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func funcName(f func()) string {
|
||||
parts := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "/go-proxy/")
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
|
||||
func printJSON(obj any) {
|
||||
j, err := E.Check(json.MarshalIndent(obj, "", " "))
|
||||
j, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
logging.Fatal().Err(err).Send()
|
||||
}
|
||||
rawLogger := log.New(os.Stdout, "", 0)
|
||||
rawLogger.Printf("%s", j) // raw output for convenience using "jq"
|
||||
|
||||
@@ -4,31 +4,26 @@ services:
|
||||
container_name: go-proxy-frontend
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
env_file: .env
|
||||
depends_on:
|
||||
- app
|
||||
# if you also want to proxy the WebUI and access it via gp.y.z
|
||||
# labels:
|
||||
# - proxy.aliases=gp
|
||||
# - proxy.gp.port=3000
|
||||
|
||||
# Make sure the value is same as `GOPROXY_API_ADDR` below (if you have changed it)
|
||||
#
|
||||
# environment:
|
||||
# GOPROXY_API_ADDR: 127.0.0.1:8888
|
||||
# modify below to fit your needs
|
||||
labels:
|
||||
proxy.aliases: gp
|
||||
proxy.#1.port: 3000
|
||||
proxy.#1.middlewares.cidr_whitelist.status_code: 403
|
||||
proxy.#1.middlewares.cidr_whitelist.message: IP not allowed
|
||||
proxy.#1.middlewares.cidr_whitelist.allow: |
|
||||
- 127.0.0.1
|
||||
- 10.0.0.0/8
|
||||
- 192.168.0.0/16
|
||||
- 172.16.0.0/12
|
||||
app:
|
||||
image: ghcr.io/yusing/go-proxy:latest
|
||||
container_name: go-proxy
|
||||
restart: always
|
||||
network_mode: host
|
||||
environment:
|
||||
# (Optional) change this to your timezone to get correct log timestamp
|
||||
TZ: ETC/UTC
|
||||
|
||||
# Change these if you need
|
||||
#
|
||||
# GOPROXY_HTTP_ADDR: :80
|
||||
# GOPROXY_HTTPS_ADDR: :443
|
||||
# GOPROXY_API_ADDR: 127.0.0.1:8888
|
||||
env_file: .env
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./config:/app/config
|
||||
|
||||
19
go.mod
19
go.mod
@@ -8,9 +8,11 @@ require (
|
||||
github.com/docker/docker v27.3.1+incompatible
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/go-acme/lego/v4 v4.19.2
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/gotify/server/v2 v2.5.0
|
||||
github.com/puzpuzpuz/xsync/v3 v3.4.0
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/text v0.19.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -19,7 +21,7 @@ require (
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.106.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.108.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
@@ -32,6 +34,8 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
@@ -40,13 +44,14 @@ require (
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/ovh/go-ovh v1.6.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect
|
||||
go.opentelemetry.io/otel v1.30.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.13.1 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
|
||||
go.opentelemetry.io/otel v1.31.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.30.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.31.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.30.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.30.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.31.0 // indirect
|
||||
golang.org/x/crypto v0.28.0 // indirect
|
||||
golang.org/x/mod v0.21.0 // indirect
|
||||
golang.org/x/oauth2 v0.23.0 // indirect
|
||||
|
||||
42
go.sum
42
go.sum
@@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cloudflare/cloudflare-go v0.106.0 h1:q41gC5Wc1nfi0D1ZhSHokWcd9mGMbqC7RE7qiP+qE00=
|
||||
github.com/cloudflare/cloudflare-go v0.106.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
|
||||
github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0=
|
||||
github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU=
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -40,8 +41,11 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
@@ -49,6 +53,8 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gotify/server/v2 v2.5.0 h1:tJd+a5bb17X52f0EV2KxqLuyjQFKmVK1+t/iNUkP16Y=
|
||||
github.com/gotify/server/v2 v2.5.0/go.mod h1:DKPMQI/FZ69iKbZvrOL6VWwRaoB9O+HDvJWVd/kiGbc=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I=
|
||||
github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc=
|
||||
@@ -59,6 +65,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
|
||||
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
@@ -84,8 +96,11 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
|
||||
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
@@ -96,20 +111,20 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI=
|
||||
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
|
||||
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
|
||||
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
|
||||
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8=
|
||||
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
|
||||
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
|
||||
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
|
||||
go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
|
||||
go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE=
|
||||
go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg=
|
||||
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
|
||||
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
|
||||
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
|
||||
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -138,6 +153,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@@ -2,13 +2,13 @@ package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/error_page"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
@@ -21,43 +21,38 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc
|
||||
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
|
||||
}
|
||||
|
||||
func NewHandler(cfg *config.Config) http.Handler {
|
||||
func NewHandler() http.Handler {
|
||||
mux := NewServeMux()
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
|
||||
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
|
||||
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload))
|
||||
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List))
|
||||
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
|
||||
mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent)
|
||||
mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent)
|
||||
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
|
||||
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS))
|
||||
mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc())
|
||||
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/reload", v1.Reload)
|
||||
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List))
|
||||
mux.HandleFunc("GET", "/v1/file", auth.RequireAuth(v1.GetFileContent))
|
||||
mux.HandleFunc("GET", "/v1/file/{filename...}", auth.RequireAuth(v1.GetFileContent))
|
||||
mux.HandleFunc("POST", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("PUT", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
|
||||
return mux
|
||||
}
|
||||
|
||||
// allow only requests to API server with host matching common.APIHTTPAddr
|
||||
// allow only requests to API server with localhost.
|
||||
func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
||||
if common.IsDebug {
|
||||
return f
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host != common.APIHTTPAddr {
|
||||
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("invalid request"))
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
|
||||
LogWarn(r).Msgf("blocked API request from %s", host)
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
f(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func wrap(cfg *config.Config, f func(cfg *config.Config, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
f(cfg, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
135
internal/api/v1/auth/auth.go
Normal file
135
internal/api/v1/auth/auth.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
Credentials struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
Claims struct {
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUsername = E.New("invalid username")
|
||||
ErrInvalidPassword = E.New("invalid password")
|
||||
)
|
||||
|
||||
func validatePassword(cred *Credentials) error {
|
||||
if cred.Username != common.APIUser {
|
||||
return ErrInvalidUsername.Subject(cred.Username)
|
||||
}
|
||||
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
|
||||
return ErrInvalidPassword.Subject(cred.Password)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds Credentials
|
||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := validatePassword(&creds); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
||||
claim := &Claims{
|
||||
Username: creds.Username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
||||
tokenStr, err := token.SignedString(common.APIJWTSecret)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "token",
|
||||
Value: tokenStr,
|
||||
Expires: expiresAt,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "token",
|
||||
Value: "",
|
||||
Expires: time.Unix(0, 0),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
w.Header().Set("location", "/login")
|
||||
w.WriteHeader(http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
if common.IsDebugSkipAuth {
|
||||
return next
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if checkToken(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
tokenCookie, err := r.Cookie("token")
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, E.PrependSubject("token", err), http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
var claims Claims
|
||||
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return common.APIJWTSecret, nil
|
||||
})
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
break
|
||||
case !token.Valid:
|
||||
err = E.New("invalid token")
|
||||
case claims.Username != common.APIUser:
|
||||
err = E.New("username mismatch").Subject(claims.Username)
|
||||
case claims.ExpiresAt.Before(time.Now()):
|
||||
err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
)
|
||||
|
||||
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
target := r.FormValue("target")
|
||||
if target == "" {
|
||||
U.HandleErr(w, r, U.ErrMissingKey("target"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var ok bool
|
||||
route := cfg.FindRoute(target)
|
||||
|
||||
switch {
|
||||
case route == nil:
|
||||
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
|
||||
return
|
||||
case route.Type() == R.RouteTypeReverseProxy:
|
||||
ok = IsSiteHealthy(route.URL().String())
|
||||
case route.Type() == R.RouteTypeStream:
|
||||
entry := route.Entry()
|
||||
ok = IsStreamHealthy(
|
||||
strings.Split(entry.Scheme, ":")[1], // target scheme
|
||||
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
|
||||
)
|
||||
}
|
||||
|
||||
if ok {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package error_page
|
||||
|
||||
import "net/http"
|
||||
|
||||
func GetHandleFunc() http.HandlerFunc {
|
||||
setup()
|
||||
return serveHTTP
|
||||
}
|
||||
|
||||
func serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/" {
|
||||
http.Error(w, "invalid path", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
content, ok := fileContentMap.Load(r.URL.Path)
|
||||
if !ok {
|
||||
http.Error(w, "404 not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Write(content)
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
"github.com/yusing/go-proxy/internal/route/provider"
|
||||
)
|
||||
|
||||
func GetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -24,7 +24,7 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
w.Write(content)
|
||||
U.WriteBody(w, content)
|
||||
}
|
||||
|
||||
func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -39,19 +39,20 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var validateErr E.NestedError
|
||||
var valErr E.Error
|
||||
if filename == common.ConfigFileName {
|
||||
validateErr = config.Validate(content)
|
||||
valErr = config.Validate(content)
|
||||
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
|
||||
validateErr = provider.Validate(content)
|
||||
valErr = provider.Validate(content)
|
||||
}
|
||||
// no validation for include files
|
||||
|
||||
if validateErr != nil {
|
||||
U.RespondJson(w, validateErr.JSONObject(), http.StatusBadRequest)
|
||||
if valErr != nil {
|
||||
U.RespondJSON(w, r, valErr, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
|
||||
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0o644)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
func IsSiteHealthy(url string) bool {
|
||||
// try HEAD first
|
||||
// if HEAD is not allowed, try GET
|
||||
resp, err := U.Head(url)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
|
||||
_, err = U.Get(url)
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func IsStreamHealthy(scheme, address string) bool {
|
||||
conn, err := net.DialTimeout(scheme, address, common.DialTimeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
@@ -1,7 +1,11 @@
|
||||
package v1
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
)
|
||||
|
||||
func Index(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("API ready"))
|
||||
WriteBody(w, []byte("API ready"))
|
||||
}
|
||||
|
||||
@@ -8,54 +8,69 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
ListRoutes = "routes"
|
||||
ListConfigFiles = "config_files"
|
||||
ListMiddlewares = "middlewares"
|
||||
ListMiddlewareTrace = "middleware_trace"
|
||||
ListMatchDomains = "match_domains"
|
||||
ListHomepageConfig = "homepage_config"
|
||||
ListRoute = "route"
|
||||
ListRoutes = "routes"
|
||||
ListConfigFiles = "config_files"
|
||||
ListMiddlewares = "middlewares"
|
||||
ListMiddlewareTraces = "middleware_trace"
|
||||
ListMatchDomains = "match_domains"
|
||||
ListHomepageConfig = "homepage_config"
|
||||
ListTasks = "tasks"
|
||||
)
|
||||
|
||||
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
func List(w http.ResponseWriter, r *http.Request) {
|
||||
what := r.PathValue("what")
|
||||
if what == "" {
|
||||
what = ListRoutes
|
||||
}
|
||||
which := r.PathValue("which")
|
||||
|
||||
switch what {
|
||||
case ListRoute:
|
||||
if route := listRoute(which); route == nil {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
} else {
|
||||
U.RespondJSON(w, r, route)
|
||||
}
|
||||
case ListRoutes:
|
||||
listRoutes(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type"))))
|
||||
case ListConfigFiles:
|
||||
listConfigFiles(w, r)
|
||||
case ListMiddlewares:
|
||||
listMiddlewares(w, r)
|
||||
case ListMiddlewareTrace:
|
||||
listMiddlewareTrace(w, r)
|
||||
U.RespondJSON(w, r, middleware.All())
|
||||
case ListMiddlewareTraces:
|
||||
U.RespondJSON(w, r, middleware.GetAllTrace())
|
||||
case ListMatchDomains:
|
||||
listMatchDomains(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.Value().MatchDomains)
|
||||
case ListHomepageConfig:
|
||||
listHomepageConfig(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.HomepageConfig())
|
||||
case ListTasks:
|
||||
U.RespondJSON(w, r, task.DebugTaskMap())
|
||||
default:
|
||||
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
routes := cfg.RoutesByAlias()
|
||||
typeFilter := r.FormValue("type")
|
||||
if typeFilter != "" {
|
||||
for k, v := range routes {
|
||||
if v["type"] != typeFilter {
|
||||
delete(routes, k)
|
||||
}
|
||||
}
|
||||
func listRoute(which string) any {
|
||||
if which == "" {
|
||||
which = "all"
|
||||
}
|
||||
|
||||
U.HandleErr(w, r, U.RespondJson(w, routes))
|
||||
if which == "all" {
|
||||
return config.RoutesByAlias()
|
||||
}
|
||||
routes := config.RoutesByAlias()
|
||||
route, ok := routes[which]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return route
|
||||
}
|
||||
|
||||
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -67,21 +82,5 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
||||
for i := range files {
|
||||
files[i] = strings.TrimPrefix(files[i], common.ConfigBasePath+"/")
|
||||
}
|
||||
U.HandleErr(w, r, U.RespondJson(w, files))
|
||||
}
|
||||
|
||||
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, U.RespondJson(w, middleware.GetAllTrace()))
|
||||
}
|
||||
|
||||
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, U.RespondJson(w, middleware.All()))
|
||||
}
|
||||
|
||||
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, U.RespondJson(w, cfg.Value().MatchDomains))
|
||||
}
|
||||
|
||||
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, U.RespondJson(w, cfg.HomepageConfig()))
|
||||
U.RespondJSON(w, r, files)
|
||||
}
|
||||
|
||||
@@ -13,57 +13,52 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
func ReloadServer() E.NestedError {
|
||||
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||
func ReloadServer() E.Error {
|
||||
resp, err := U.Post(common.APIHTTPURL+"/v1/reload", "", nil)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
failure := E.Errorf("server reload status %v", resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return failure.Extraf("unable to read response body: %s", err)
|
||||
return failure.With(err)
|
||||
}
|
||||
reloadErr, ok := E.FromJSON(b)
|
||||
if ok {
|
||||
return E.Join("reload success, but server returned error", reloadErr)
|
||||
}
|
||||
return failure.Extraf("unable to read response body")
|
||||
reloadErr := string(body)
|
||||
return failure.Withf(reloadErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
|
||||
func List[T any](what string) (_ T, outErr E.Error) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
outErr = E.From(err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode)
|
||||
outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
var routes map[string]map[string]any
|
||||
err = json.NewDecoder(resp.Body).Decode(&routes)
|
||||
var res T
|
||||
err = json.NewDecoder(resp.Body).Decode(&res)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
outErr = E.From(err)
|
||||
return
|
||||
}
|
||||
return routes, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
|
||||
}
|
||||
var traces middleware.Traces
|
||||
err = json.NewDecoder(resp.Body).Decode(&traces)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
return traces, nil
|
||||
func ListRoutes() (map[string]map[string]any, E.Error) {
|
||||
return List[map[string]map[string]any](v1.ListRoutes)
|
||||
}
|
||||
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.Error) {
|
||||
return List[middleware.Traces](v1.ListMiddlewareTraces)
|
||||
}
|
||||
|
||||
func DebugListTasks() (map[string]any, E.Error) {
|
||||
return List[map[string]any](v1.ListTasks)
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
)
|
||||
|
||||
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
if err := cfg.Reload(); err != nil {
|
||||
U.RespondJson(w, err.JSONObject(), http.StatusInternalServerError)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
func Reload(w http.ResponseWriter, r *http.Request) {
|
||||
if err := config.Reload(); err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
U.WriteBody(w, []byte("OK"))
|
||||
}
|
||||
|
||||
@@ -5,33 +5,35 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/server"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.HandleErr(w, r, U.RespondJson(w, getStats(cfg)))
|
||||
func Stats(w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, getStats())
|
||||
}
|
||||
|
||||
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses))
|
||||
func StatsWS(w http.ResponseWriter, r *http.Request) {
|
||||
var originPats []string
|
||||
|
||||
if len(originPats) == 0 {
|
||||
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
|
||||
if len(config.Value().MatchDomains) == 0 {
|
||||
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
for i, domain := range cfg.Value().MatchDomains {
|
||||
originPats[i] = "*." + domain
|
||||
originPats = make([]string, len(config.Value().MatchDomains))
|
||||
for i, domain := range config.Value().MatchDomains {
|
||||
originPats[i] = "*" + domain
|
||||
}
|
||||
originPats = append(originPats, localAddresses...)
|
||||
}
|
||||
U.LogInfo(r).Msgf("websocket API request from origins: %s", originPats)
|
||||
if common.IsDebug {
|
||||
originPats = []string{"*"}
|
||||
}
|
||||
@@ -39,9 +41,10 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
OriginPatterns: originPats,
|
||||
})
|
||||
if err != nil {
|
||||
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err)
|
||||
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
|
||||
return
|
||||
}
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
defer conn.CloseNow()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -51,17 +54,17 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
stats := getStats(cfg)
|
||||
stats := getStats()
|
||||
if err := wsjson.Write(ctx, conn, stats); err != nil {
|
||||
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
|
||||
U.LogError(r).Msg("failed to write JSON")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getStats(cfg *config.Config) map[string]any {
|
||||
func getStats() map[string]any {
|
||||
return map[string]any{
|
||||
"proxies": cfg.Statistics(),
|
||||
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
|
||||
"proxies": config.Statistics(),
|
||||
"uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +1,36 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
var Logger = logrus.WithField("module", "api")
|
||||
|
||||
// HandleErr logs the error and returns an HTTP error response to the client.
|
||||
// If code is specified, it will be used as the HTTP status code; otherwise,
|
||||
// http.StatusInternalServerError is used.
|
||||
//
|
||||
// The error is only logged but not returned to the client.
|
||||
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
|
||||
if origErr == nil {
|
||||
return
|
||||
}
|
||||
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
|
||||
Logger.Error(err)
|
||||
LogError(r).Msg(origErr.Error())
|
||||
statusCode := http.StatusInternalServerError
|
||||
if len(code) > 0 {
|
||||
http.Error(w, err.String(), code[0])
|
||||
return
|
||||
statusCode = code[0]
|
||||
}
|
||||
http.Error(w, err.String(), http.StatusInternalServerError)
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
}
|
||||
|
||||
func ErrMissingKey(k string) error {
|
||||
return errors.New("missing key '" + k + "' in query or request body")
|
||||
return E.New("missing key '" + k + "' in query or request body")
|
||||
}
|
||||
|
||||
func ErrInvalidKey(k string) error {
|
||||
return errors.New("invalid key '" + k + "' in query or request body")
|
||||
return E.New("invalid key '" + k + "' in query or request body")
|
||||
}
|
||||
|
||||
func ErrNotFound(k, v string) error {
|
||||
return fmt.Errorf("key %q with value %q not found", k, v)
|
||||
return E.Errorf("key %q with value %q not found", k, v)
|
||||
}
|
||||
|
||||
@@ -8,20 +8,21 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var HTTPClient = &http.Client{
|
||||
Timeout: common.ConnectionTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: common.DialTimeout,
|
||||
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
|
||||
}).DialContext,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
var (
|
||||
httpClient = &http.Client{
|
||||
Timeout: common.ConnectionTimeout,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: common.DialTimeout,
|
||||
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
|
||||
}).DialContext,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
|
||||
var Get = HTTPClient.Get
|
||||
var Post = HTTPClient.Post
|
||||
var Head = HTTPClient.Head
|
||||
Get = httpClient.Get
|
||||
Post = httpClient.Post
|
||||
Head = httpClient.Head
|
||||
)
|
||||
|
||||
18
internal/api/v1/utils/logging.go
Normal file
18
internal/api/v1/utils/logging.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||
return logging.WithLevel(level).Str("module", "api").
|
||||
Str("method", r.Method).
|
||||
Str("path", r.RequestURI)
|
||||
}
|
||||
|
||||
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
||||
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
|
||||
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }
|
||||
@@ -2,19 +2,42 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func RespondJson(w http.ResponseWriter, data any, code ...int) error {
|
||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||
if _, err := w.Write(body); err != nil {
|
||||
HandleErr(w, nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
|
||||
if len(code) > 0 {
|
||||
w.WriteHeader(code[0])
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
j, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
w.Write(j)
|
||||
var j []byte
|
||||
var err error
|
||||
|
||||
switch data := data.(type) {
|
||||
case string:
|
||||
j = []byte(fmt.Sprintf("%q", data))
|
||||
case []byte:
|
||||
j = data
|
||||
default:
|
||||
j, err = json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("failed to marshal json")
|
||||
return false
|
||||
}
|
||||
}
|
||||
return nil
|
||||
_, err = w.Write(j)
|
||||
if err != nil {
|
||||
HandleErr(w, r, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3,9 +3,10 @@ package v1
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
)
|
||||
|
||||
func GetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(pkg.GetVersion()))
|
||||
WriteBody(w, []byte(pkg.GetVersion()))
|
||||
}
|
||||
|
||||
@@ -8,12 +8,21 @@ import (
|
||||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
)
|
||||
|
||||
type Config types.AutoCertConfig
|
||||
|
||||
var (
|
||||
ErrMissingDomain = E.New("missing field 'domains'")
|
||||
ErrMissingEmail = E.New("missing field 'email'")
|
||||
ErrMissingProvider = E.New("missing field 'provider'")
|
||||
ErrUnknownProvider = E.New("unknown provider")
|
||||
)
|
||||
|
||||
func NewConfig(cfg *types.AutoCertConfig) *Config {
|
||||
if cfg.CertPath == "" {
|
||||
cfg.CertPath = CertFileDefault
|
||||
@@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
|
||||
return (*Config)(cfg)
|
||||
}
|
||||
|
||||
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
|
||||
b := E.NewBuilder("unable to initialize autocert")
|
||||
defer b.To(&res)
|
||||
func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
||||
b := E.NewBuilder("autocert errors")
|
||||
|
||||
if cfg.Provider != ProviderLocal {
|
||||
if len(cfg.Domains) == 0 {
|
||||
b.Addf("%s", "no domains specified")
|
||||
b.Add(ErrMissingDomain)
|
||||
}
|
||||
if cfg.Provider == "" {
|
||||
b.Addf("%s", "no provider specified")
|
||||
b.Add(ErrMissingProvider)
|
||||
}
|
||||
if cfg.Email == "" {
|
||||
b.Addf("%s", "no email specified")
|
||||
b.Add(ErrMissingEmail)
|
||||
}
|
||||
// check if provider is implemented
|
||||
_, ok := providersGenMap[cfg.Provider]
|
||||
if !ok {
|
||||
b.Addf("unknown provider: %q", cfg.Provider)
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
||||
}
|
||||
}
|
||||
|
||||
if b.HasError() {
|
||||
return
|
||||
return nil, b.Error()
|
||||
}
|
||||
|
||||
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("generate private key", err))
|
||||
return
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
b.Addf("generate private key: %w", err)
|
||||
return nil, b.Error()
|
||||
}
|
||||
|
||||
user := &User{
|
||||
@@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
|
||||
legoCfg := lego.NewConfig(user)
|
||||
legoCfg.Certificate.KeyType = certcrypto.RSA2048
|
||||
|
||||
provider = &Provider{
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
user: user,
|
||||
legoCfg: legoCfg,
|
||||
}
|
||||
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/go-acme/lego/v4/providers/dns/clouddns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/duckdns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/ovh"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -32,9 +29,3 @@ var providersGenMap = map[string]ProviderGenerator{
|
||||
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
|
||||
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
|
||||
}
|
||||
|
||||
var (
|
||||
ErrGetCertFailure = errors.New("get certificate failed")
|
||||
)
|
||||
|
||||
var logger = logrus.WithField("module", "autocert")
|
||||
|
||||
5
internal/autocert/logger.go
Normal file
5
internal/autocert/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package autocert
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "autocert").Logger()
|
||||
@@ -1,9 +1,9 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
@@ -14,24 +14,29 @@ import (
|
||||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
cfg *Config
|
||||
user *User
|
||||
legoCfg *lego.Config
|
||||
client *lego.Client
|
||||
type (
|
||||
Provider struct {
|
||||
cfg *Config
|
||||
user *User
|
||||
legoCfg *lego.Config
|
||||
client *lego.Client
|
||||
|
||||
tlsCert *tls.Certificate
|
||||
certExpiries CertExpiries
|
||||
}
|
||||
tlsCert *tls.Certificate
|
||||
certExpiries CertExpiries
|
||||
}
|
||||
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
|
||||
|
||||
type ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
|
||||
type CertExpiries map[string]time.Time
|
||||
CertExpiries map[string]time.Time
|
||||
)
|
||||
|
||||
var ErrGetCertFailure = errors.New("get certificate failed")
|
||||
|
||||
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if p.tlsCert == nil {
|
||||
@@ -56,25 +61,20 @@ func (p *Provider) GetExpiries() CertExpiries {
|
||||
return p.certExpiries
|
||||
}
|
||||
|
||||
func (p *Provider) ObtainCert() (res E.NestedError) {
|
||||
b := E.NewBuilder("failed to obtain certificate")
|
||||
defer b.To(&res)
|
||||
|
||||
func (p *Provider) ObtainCert() E.Error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.client == nil {
|
||||
if err := p.initClient(); err.HasError() {
|
||||
b.Add(E.FailWith("init autocert client", err))
|
||||
return
|
||||
if err := p.initClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if p.user.Registration == nil {
|
||||
if err := p.registerACME(); err.HasError() {
|
||||
b.Add(E.FailWith("register ACME", err))
|
||||
return
|
||||
if err := p.registerACME(); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,27 +83,23 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
|
||||
Domains: p.cfg.Domains,
|
||||
Bundle: true,
|
||||
}
|
||||
cert, err := E.Check(client.Certificate.Obtain(req))
|
||||
if err.HasError() {
|
||||
b.Add(err)
|
||||
return
|
||||
cert, err := client.Certificate.Obtain(req)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
if err = p.saveCert(cert); err.HasError() {
|
||||
b.Add(E.FailWith("save certificate", err))
|
||||
return
|
||||
if err = p.saveCert(cert); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("parse obtained certificate", err))
|
||||
return
|
||||
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
expiries, err := getCertExpiries(&tlsCert)
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("get certificate expiry", err))
|
||||
return
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
p.tlsCert = &tlsCert
|
||||
p.certExpiries = expiries
|
||||
@@ -111,22 +107,23 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) LoadCert() E.NestedError {
|
||||
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
|
||||
if err.HasError() {
|
||||
return err
|
||||
func (p *Provider) LoadCert() E.Error {
|
||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||
if err != nil {
|
||||
return E.Errorf("load SSL certificate: %w", err)
|
||||
}
|
||||
expiries, err := getCertExpiries(&cert)
|
||||
if err.HasError() {
|
||||
return err
|
||||
if err != nil {
|
||||
return E.Errorf("parse SSL certificate: %w", err)
|
||||
}
|
||||
p.tlsCert = &cert
|
||||
p.certExpiries = expiries
|
||||
|
||||
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
return p.renewIfNeeded()
|
||||
}
|
||||
|
||||
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
||||
func (p *Provider) ShouldRenewOn() time.Time {
|
||||
for _, expiry := range p.certExpiries {
|
||||
return expiry.AddDate(0, -1, 0) // 1 month before
|
||||
@@ -135,55 +132,55 @@ func (p *Provider) ShouldRenewOn() time.Time {
|
||||
panic("no certificate available")
|
||||
}
|
||||
|
||||
func (p *Provider) ScheduleRenewal(ctx context.Context) {
|
||||
func (p *Provider) ScheduleRenewal() {
|
||||
if p.GetName() == ProviderLocal {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("started renewal scheduler")
|
||||
defer logger.Debug("renewal scheduler stopped")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C: // check every 5 seconds
|
||||
if err := p.renewIfNeeded(); err.HasError() {
|
||||
logger.Warn(err)
|
||||
go func() {
|
||||
task := task.GlobalTask("cert renew scheduler")
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer task.Finish("cert renew scheduler stopped")
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-ticker.C: // check every 5 seconds
|
||||
if err := p.renewIfNeeded(); err != nil {
|
||||
E.LogWarn("cert renew failed", err, &logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Provider) initClient() E.NestedError {
|
||||
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
|
||||
if err.HasError() {
|
||||
return E.FailWith("create lego client", err)
|
||||
func (p *Provider) initClient() E.Error {
|
||||
legoClient, err := lego.NewClient(p.legoCfg)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
|
||||
if err.HasError() {
|
||||
return E.FailWith("create lego provider", err)
|
||||
generator := providersGenMap[p.cfg.Provider]
|
||||
legoProvider, pErr := generator(p.cfg.Options)
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
}
|
||||
|
||||
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
|
||||
if err.HasError() {
|
||||
return E.FailWith("set challenge provider", err)
|
||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
p.client = legoClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) registerACME() E.NestedError {
|
||||
func (p *Provider) registerACME() error {
|
||||
if p.user.Registration != nil {
|
||||
return nil
|
||||
}
|
||||
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
|
||||
if err.HasError() {
|
||||
reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.user.Registration = reg
|
||||
@@ -191,26 +188,27 @@ func (p *Provider) registerACME() E.NestedError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
||||
//* This should have been done in setup
|
||||
//* but double check is always a good choice
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) error {
|
||||
/* This should have been done in setup
|
||||
but double check is always a good choice.*/
|
||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
|
||||
return E.FailWith("create cert directory", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return E.FailWith("stat cert directory", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
||||
if err != nil {
|
||||
return E.FailWith("write key file", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
|
||||
if err != nil {
|
||||
return E.FailWith("write cert file", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -232,39 +230,36 @@ func (p *Provider) certState() CertState {
|
||||
sort.Strings(certDomains)
|
||||
|
||||
if !reflect.DeepEqual(certDomains, wantedDomains) {
|
||||
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
||||
logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
||||
return CertStateMismatch
|
||||
}
|
||||
|
||||
return CertStateValid
|
||||
}
|
||||
|
||||
func (p *Provider) renewIfNeeded() E.NestedError {
|
||||
func (p *Provider) renewIfNeeded() E.Error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
logger.Info("certs expired, renewing")
|
||||
logger.Info().Msg("certs expired, renewing")
|
||||
case CertStateMismatch:
|
||||
logger.Info("cert domains mismatch with config, renewing")
|
||||
logger.Info().Msg("cert domains mismatch with config, renewing")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.ObtainCert(); err.HasError() {
|
||||
return E.FailWith("renew certificate", err)
|
||||
}
|
||||
return nil
|
||||
return p.ObtainCert()
|
||||
}
|
||||
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||
r := make(CertExpiries, len(cert.Certificate))
|
||||
for _, cert := range cert.Certificate {
|
||||
x509Cert, err := E.Check(x509.ParseCertificate(cert))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("parse certificate", err)
|
||||
x509Cert, err := x509.ParseCertificate(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x509Cert.IsCA {
|
||||
continue
|
||||
@@ -281,16 +276,13 @@ func providerGenerator[CT any, PT challenge.Provider](
|
||||
defaultCfg func() *CT,
|
||||
newProvider func(*CT) (PT, error),
|
||||
) ProviderGenerator {
|
||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
|
||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := U.Deserialize(opt, cfg)
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, err := E.Check(newProvider(cfg))
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
p, pErr := newProvider(cfg)
|
||||
return p, E.From(pErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,6 @@ oauth2_config:
|
||||
testYaml = testYaml[1:] // remove first \n
|
||||
opt := make(map[string]any)
|
||||
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
|
||||
ExpectNoError(t, U.Deserialize(opt, cfg).Error())
|
||||
ExpectNoError(t, U.Deserialize(opt, cfg))
|
||||
ExpectDeepEqual(t, cfg, cfgExpected)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
|
||||
func (p *Provider) Setup() (err E.Error) {
|
||||
if err = p.LoadCert(); err != nil {
|
||||
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
|
||||
return err
|
||||
}
|
||||
logger.Debug("obtaining cert due to error loading cert")
|
||||
logger.Debug().Msg("obtaining cert due to error loading cert")
|
||||
if err = p.ObtainCert(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
go p.ScheduleRenewal(ctx)
|
||||
p.ScheduleRenewal()
|
||||
|
||||
for _, expiry := range p.GetExpiries() {
|
||||
logger.Infof("certificate expire on %s", expiry)
|
||||
logger.Info().Msg("certificate expire on " + expiry.String())
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"crypto"
|
||||
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
@@ -19,4 +20,4 @@ func (u *User) GetRegistration() *registration.Resource {
|
||||
}
|
||||
func (u *User) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,7 @@ package common
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log"
|
||||
)
|
||||
|
||||
type Args struct {
|
||||
@@ -42,7 +41,7 @@ func GetArgs() Args {
|
||||
flag.Parse()
|
||||
args.Command = flag.Arg(0)
|
||||
if err := validateArg(args.Command); err != nil {
|
||||
logrus.Fatal(err)
|
||||
log.Fatalf("invalid command: %s", err)
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -53,5 +52,5 @@ func validateArg(arg string) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid command: %s", arg)
|
||||
return fmt.Errorf("invalid command %q", arg)
|
||||
}
|
||||
|
||||
@@ -13,43 +13,44 @@ const (
|
||||
// file, folder structure
|
||||
|
||||
const (
|
||||
DotEnvPath = ".env"
|
||||
DotEnvExamplePath = ".env.example"
|
||||
|
||||
ConfigBasePath = "config"
|
||||
ConfigFileName = "config.yml"
|
||||
ConfigExampleFileName = "config.example.yml"
|
||||
ConfigPath = ConfigBasePath + "/" + ConfigFileName
|
||||
|
||||
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
|
||||
)
|
||||
JWTKeyPath = ConfigBasePath + "/jwt.key"
|
||||
|
||||
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
|
||||
|
||||
const (
|
||||
SchemaBasePath = "schema"
|
||||
ConfigSchemaPath = SchemaBasePath + "/config.schema.json"
|
||||
FileProviderSchemaPath = SchemaBasePath + "/providers.schema.json"
|
||||
)
|
||||
|
||||
const (
|
||||
ComposeFileName = "compose.yml"
|
||||
ComposeExampleFileName = "compose.example.yml"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrorPagesBasePath = "error_pages"
|
||||
)
|
||||
|
||||
var (
|
||||
RequiredDirectories = []string{
|
||||
ConfigBasePath,
|
||||
SchemaBasePath,
|
||||
ErrorPagesBasePath,
|
||||
MiddlewareComposeBasePath,
|
||||
}
|
||||
)
|
||||
var RequiredDirectories = []string{
|
||||
ConfigBasePath,
|
||||
SchemaBasePath,
|
||||
ErrorPagesBasePath,
|
||||
MiddlewareComposeBasePath,
|
||||
}
|
||||
|
||||
const DockerHostFromEnv = "$DOCKER_HOST"
|
||||
|
||||
const (
|
||||
IdleTimeoutDefault = "0"
|
||||
HealthCheckIntervalDefault = 5 * time.Second
|
||||
HealthCheckTimeoutDefault = 5 * time.Second
|
||||
|
||||
WakeTimeoutDefault = "30s"
|
||||
StopTimeoutDefault = "10s"
|
||||
StopMethodDefault = "stop"
|
||||
)
|
||||
|
||||
const HeaderCheckRedirect = "X-Goproxy-Check-Redirect"
|
||||
|
||||
31
internal/common/crypto.go
Normal file
31
internal/common/crypto.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func HashPassword(pwd string) []byte {
|
||||
h := sha512.New()
|
||||
h.Write([]byte(pwd))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func generateJWTKey(size int) string {
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
log.Panic().Err(err).Msg("failed to generate jwt key")
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
func decodeJWTKey(key string) []byte {
|
||||
bytes, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msg("failed to decode jwt key")
|
||||
}
|
||||
return bytes
|
||||
}
|
||||
@@ -2,19 +2,21 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true)
|
||||
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
|
||||
IsDebugSkipAuth = GetEnvBool("GOPROXY_DEBUG_SKIP_AUTH", false)
|
||||
IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug
|
||||
|
||||
ProxyHTTPAddr,
|
||||
ProxyHTTPHost,
|
||||
@@ -30,6 +32,11 @@ var (
|
||||
APIHTTPHost,
|
||||
APIHTTPPort,
|
||||
APIHTTPURL = GetAddrEnv("GOPROXY_API_ADDR", "127.0.0.1:8888", "http")
|
||||
|
||||
APIJWTSecret = decodeJWTKey(GetEnv("GOPROXY_API_JWT_SECRET", generateJWTKey(32)))
|
||||
APIJWTTokenTTL = GetDurationEnv("GOPROXY_API_JWT_TOKEN_TTL", time.Hour)
|
||||
APIUser = GetEnv("GOPROXY_API_USER", "admin")
|
||||
APIPasswordHash = HashPassword(GetEnv("GOPROXY_API_PASSWORD", "password"))
|
||||
)
|
||||
|
||||
func GetEnvBool(key string, defaultValue bool) bool {
|
||||
@@ -39,7 +46,7 @@ func GetEnvBool(key string, defaultValue bool) bool {
|
||||
}
|
||||
b, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid boolean value: %s", value)
|
||||
log.Fatal().Msgf("env %s: invalid boolean value: %s", key, value)
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -56,7 +63,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
||||
addr = GetEnv(key, defaultValue)
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Invalid address: %s", addr)
|
||||
log.Fatal().Msgf("env %s: invalid address: %s", key, addr)
|
||||
}
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
@@ -64,3 +71,15 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
||||
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
|
||||
return
|
||||
}
|
||||
|
||||
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
|
||||
value, ok := os.LookupEnv(key)
|
||||
if !ok || value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
log.Fatal().Msgf("env %s: invalid duration value: %s", key, value)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -1,236 +1,249 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
||||
PR "github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
value *types.Config
|
||||
proxyProviders F.Map[string, *PR.Provider]
|
||||
providers F.Map[string, *proxy.Provider]
|
||||
autocertProvider *autocert.Provider
|
||||
|
||||
l logrus.FieldLogger
|
||||
|
||||
watcher W.Watcher
|
||||
watcherCtx context.Context
|
||||
watcherCancel context.CancelFunc
|
||||
reloadReq chan struct{}
|
||||
task task.Task
|
||||
}
|
||||
|
||||
var instance *Config
|
||||
var (
|
||||
instance *Config
|
||||
cfgWatcher watcher.Watcher
|
||||
logger = logging.With().Str("module", "config").Logger()
|
||||
reloadMu sync.Mutex
|
||||
)
|
||||
|
||||
const configEventFlushInterval = 500 * time.Millisecond
|
||||
|
||||
const (
|
||||
cfgRenameWarn = `Config file renamed, not reloading.
|
||||
Make sure you rename it back before next time you start.`
|
||||
cfgDeleteWarn = `Config file deleted, not reloading.
|
||||
You may run "ls-config" to show or dump the current config.`
|
||||
)
|
||||
|
||||
func GetInstance() *Config {
|
||||
return instance
|
||||
}
|
||||
|
||||
func Load() E.NestedError {
|
||||
if instance != nil {
|
||||
return nil
|
||||
func newConfig() *Config {
|
||||
return &Config{
|
||||
value: types.DefaultConfig(),
|
||||
providers: F.NewMapOf[string, *proxy.Provider](),
|
||||
task: task.GlobalTask("config"),
|
||||
}
|
||||
instance = &Config{
|
||||
value: types.DefaultConfig(),
|
||||
proxyProviders: F.NewMapOf[string, *PR.Provider](),
|
||||
l: logrus.WithField("module", "config"),
|
||||
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
|
||||
reloadReq: make(chan struct{}, 1),
|
||||
}
|
||||
return instance.load()
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
func Load() (*Config, E.Error) {
|
||||
if instance != nil {
|
||||
return instance, nil
|
||||
}
|
||||
instance = newConfig()
|
||||
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
|
||||
return instance, instance.load()
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.Error {
|
||||
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
|
||||
}
|
||||
|
||||
func MatchDomains() []string {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return instance.value.MatchDomains
|
||||
}
|
||||
|
||||
func (cfg *Config) Value() types.Config {
|
||||
if cfg == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return *cfg.value
|
||||
func WatchChanges() {
|
||||
task := task.GlobalTask("Config watcher")
|
||||
eventQueue := events.NewEventQueue(
|
||||
task,
|
||||
configEventFlushInterval,
|
||||
OnConfigChange,
|
||||
func(err E.Error) {
|
||||
E.LogError("config reload error", err, &logger)
|
||||
},
|
||||
)
|
||||
eventQueue.Start(cfgWatcher.Events(task.Context()))
|
||||
}
|
||||
|
||||
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
func OnConfigChange(flushTask task.Task, ev []events.Event) {
|
||||
defer flushTask.Finish("config reload complete")
|
||||
|
||||
// no matter how many events during the interval
|
||||
// just reload once and check the last event
|
||||
switch ev[len(ev)-1].Action {
|
||||
case events.ActionFileRenamed:
|
||||
logger.Warn().Msg(cfgRenameWarn)
|
||||
return
|
||||
case events.ActionFileDeleted:
|
||||
logger.Warn().Msg(cfgDeleteWarn)
|
||||
return
|
||||
}
|
||||
|
||||
if err := Reload(); err != nil {
|
||||
// recovered in event queue
|
||||
panic(err)
|
||||
}
|
||||
return cfg.autocertProvider
|
||||
}
|
||||
|
||||
func (cfg *Config) Dispose() {
|
||||
if cfg.watcherCancel != nil {
|
||||
cfg.watcherCancel()
|
||||
cfg.l.Debug("stopped watcher")
|
||||
func Reload() E.Error {
|
||||
// avoid race between config change and API reload request
|
||||
reloadMu.Lock()
|
||||
defer reloadMu.Unlock()
|
||||
|
||||
newCfg := newConfig()
|
||||
err := newCfg.load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.stopProviders()
|
||||
|
||||
// cancel all current subtasks -> wait
|
||||
// -> replace config -> start new subtasks
|
||||
instance.task.Finish("config changed")
|
||||
instance.task.Wait()
|
||||
*instance = *newCfg
|
||||
instance.StartProxyProviders()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *Config) Reload() (err E.NestedError) {
|
||||
cfg.stopProviders()
|
||||
err = cfg.load()
|
||||
cfg.StartProxyProviders()
|
||||
return
|
||||
func Value() types.Config {
|
||||
return *instance.value
|
||||
}
|
||||
|
||||
func GetAutoCertProvider() *autocert.Provider {
|
||||
return instance.autocertProvider
|
||||
}
|
||||
|
||||
func (cfg *Config) Task() task.Task {
|
||||
return cfg.task
|
||||
}
|
||||
|
||||
func (cfg *Config) StartProxyProviders() {
|
||||
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
|
||||
}
|
||||
|
||||
func (cfg *Config) WatchChanges() {
|
||||
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-cfg.watcherCtx.Done():
|
||||
return
|
||||
case <-cfg.reloadReq:
|
||||
if err := cfg.Reload(); err.HasError() {
|
||||
cfg.l.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
eventCh, errCh := cfg.watcher.Events(cfg.watcherCtx)
|
||||
for {
|
||||
select {
|
||||
case <-cfg.watcherCtx.Done():
|
||||
return
|
||||
case event := <-eventCh:
|
||||
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
|
||||
cfg.l.Error("config file deleted or renamed, ignoring...")
|
||||
continue
|
||||
} else {
|
||||
cfg.reloadReq <- struct{}{}
|
||||
}
|
||||
case err := <-errCh:
|
||||
cfg.l.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
|
||||
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
|
||||
p.RangeRoutes(func(a string, r R.Route) {
|
||||
do(a, r, p)
|
||||
errs := cfg.providers.CollectErrorsParallel(
|
||||
func(_ string, p *proxy.Provider) error {
|
||||
subtask := cfg.task.Subtask(p.String())
|
||||
return p.Start(subtask)
|
||||
})
|
||||
})
|
||||
|
||||
if err := E.Join(errs...); err != nil {
|
||||
E.LogError("route provider errors", err, &logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) load() (res E.NestedError) {
|
||||
b := E.NewBuilder("errors loading config")
|
||||
defer b.To(&res)
|
||||
func (cfg *Config) load() E.Error {
|
||||
const errMsg = "config load error"
|
||||
|
||||
cfg.l.Debug("loading config")
|
||||
defer cfg.l.Debug("loaded config")
|
||||
|
||||
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("read config", err))
|
||||
logrus.Fatal(b.Build())
|
||||
data, err := os.ReadFile(common.ConfigPath)
|
||||
if err != nil {
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
|
||||
if !common.NoSchemaValidation {
|
||||
if err = Validate(data); err.HasError() {
|
||||
b.Add(E.FailWith("schema validation", err))
|
||||
logrus.Fatal(b.Build())
|
||||
if err := Validate(data); err != nil {
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
}
|
||||
|
||||
model := types.DefaultConfig()
|
||||
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
|
||||
b.Add(E.FailWith("parse config", err))
|
||||
logrus.Fatal(b.Build())
|
||||
if err := E.From(yaml.Unmarshal(data, model)); err != nil {
|
||||
E.LogFatal(errMsg, err, &logger)
|
||||
}
|
||||
|
||||
// errors are non fatal below
|
||||
b.Add(cfg.initAutoCert(&model.AutoCert))
|
||||
b.Add(cfg.loadProviders(&model.Providers))
|
||||
errs := E.NewBuilder(errMsg)
|
||||
errs.Add(cfg.initNotification(model.Providers.Notification))
|
||||
errs.Add(cfg.initAutoCert(&model.AutoCert))
|
||||
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
||||
|
||||
cfg.value = model
|
||||
R.SetFindMuxDomains(model.MatchDomains)
|
||||
return
|
||||
for i, domain := range model.MatchDomains {
|
||||
if !strings.HasPrefix(domain, ".") {
|
||||
model.MatchDomains[i] = "." + domain
|
||||
}
|
||||
}
|
||||
route.SetFindMuxDomains(model.MatchDomains)
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) {
|
||||
func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) {
|
||||
if len(notifCfgMap) == 0 {
|
||||
return
|
||||
}
|
||||
errs := E.NewBuilder("notification providers load errors")
|
||||
for name, notifCfg := range notifCfgMap {
|
||||
_, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg)
|
||||
errs.Add(err)
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
|
||||
if cfg.autocertProvider != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.l.Debug("initializing autocert")
|
||||
defer cfg.l.Debug("initialized autocert")
|
||||
|
||||
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
|
||||
if err.HasError() {
|
||||
err = E.FailWith("autocert provider", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) {
|
||||
cfg.l.Debug("loading providers")
|
||||
defer cfg.l.Debug("loaded providers")
|
||||
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
|
||||
subtask := cfg.task.Subtask("load route providers")
|
||||
defer subtask.Finish("done")
|
||||
|
||||
b := E.NewBuilder("errors loading providers")
|
||||
defer b.To(&res)
|
||||
errs := E.NewBuilder("route provider errors")
|
||||
results := E.NewBuilder("loaded route providers")
|
||||
|
||||
lenLongestName := 0
|
||||
for _, filename := range providers.Files {
|
||||
p, err := PR.NewFileProvider(filename)
|
||||
p, err := proxy.NewFileProvider(filename)
|
||||
if err != nil {
|
||||
b.Add(err.Subject(filename))
|
||||
errs.Add(E.PrependSubject(filename, err))
|
||||
continue
|
||||
}
|
||||
cfg.proxyProviders.Store(p.GetName(), p)
|
||||
b.Add(p.LoadRoutes().Subject(filename))
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
if len(p.GetName()) > lenLongestName {
|
||||
lenLongestName = len(p.GetName())
|
||||
}
|
||||
}
|
||||
for name, dockerHost := range providers.Docker {
|
||||
p, err := PR.NewDockerProvider(name, dockerHost)
|
||||
p, err := proxy.NewDockerProvider(name, dockerHost)
|
||||
if err != nil {
|
||||
b.Add(err.Subjectf("%s (%s)", name, dockerHost))
|
||||
errs.Add(E.PrependSubject(name, err))
|
||||
continue
|
||||
}
|
||||
cfg.proxyProviders.Store(p.GetName(), p)
|
||||
b.Add(p.LoadRoutes().Subject(dockerHost))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
|
||||
errors := E.NewBuilder("errors in %s these providers", action)
|
||||
|
||||
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
|
||||
if err := do(p); err.HasError() {
|
||||
errors.Add(err.Subject(p))
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
if len(p.GetName()) > lenLongestName {
|
||||
lenLongestName = len(p.GetName())
|
||||
}
|
||||
})
|
||||
|
||||
if err := errors.Build(); err.HasError() {
|
||||
cfg.l.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) stopProviders() {
|
||||
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
|
||||
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
|
||||
if err := p.LoadRoutes(); err != nil {
|
||||
errs.Add(err.Subject(p.String()))
|
||||
}
|
||||
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes())
|
||||
})
|
||||
logger.Info().Msg(results.String())
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -5,34 +5,36 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
H "github.com/yusing/go-proxy/internal/homepage"
|
||||
PR "github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
|
||||
entries := make(map[string]*types.RawEntry)
|
||||
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
|
||||
entries[alias] = r.Entry()
|
||||
func DumpEntries() map[string]*entry.RawEntry {
|
||||
entries := make(map[string]*entry.RawEntry)
|
||||
instance.providers.RangeAll(func(_ string, p *proxy.Provider) {
|
||||
p.RangeRoutes(func(alias string, r *route.Route) {
|
||||
entries[alias] = r.Entry
|
||||
})
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
|
||||
entries := make(map[string]*PR.Provider)
|
||||
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
|
||||
func DumpProviders() map[string]*proxy.Provider {
|
||||
entries := make(map[string]*proxy.Provider)
|
||||
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
|
||||
entries[name] = p
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (cfg *Config) HomepageConfig() H.HomePageConfig {
|
||||
func HomepageConfig() homepage.Config {
|
||||
var proto, port string
|
||||
domains := cfg.value.MatchDomains
|
||||
cert, _ := cfg.autocertProvider.GetCert(nil)
|
||||
domains := instance.value.MatchDomains
|
||||
cert, _ := instance.autocertProvider.GetCert(nil)
|
||||
if cert != nil {
|
||||
proto = "https"
|
||||
port = common.ProxyHTTPSPort
|
||||
@@ -41,31 +43,25 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
|
||||
port = common.ProxyHTTPPort
|
||||
}
|
||||
|
||||
hpCfg := H.NewHomePageConfig()
|
||||
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
|
||||
if !r.Started() {
|
||||
return
|
||||
}
|
||||
|
||||
entry := r.Entry()
|
||||
if entry.Homepage == nil {
|
||||
entry.Homepage = &H.HomePageItem{
|
||||
Show: r.Entry().IsExplicit || !p.IsExplicitOnly(),
|
||||
}
|
||||
}
|
||||
|
||||
item := entry.Homepage
|
||||
|
||||
if !item.Show && !item.IsEmpty() {
|
||||
hpCfg := homepage.NewHomePageConfig()
|
||||
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
|
||||
en := r.Raw
|
||||
item := en.Homepage
|
||||
if item == nil {
|
||||
item = new(homepage.Item)
|
||||
item.Show = true
|
||||
}
|
||||
|
||||
if !item.Show || r.Type() != R.RouteTypeReverseProxy {
|
||||
if !item.IsEmpty() {
|
||||
item.Show = true
|
||||
}
|
||||
|
||||
if !item.Show {
|
||||
return
|
||||
}
|
||||
|
||||
if item.Name == "" {
|
||||
item.Name = U.Title(
|
||||
item.Name = strutils.Title(
|
||||
strings.ReplaceAll(
|
||||
strings.ReplaceAll(alias, "-", " "),
|
||||
"_", " ",
|
||||
@@ -73,16 +69,22 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
|
||||
)
|
||||
}
|
||||
|
||||
if p.GetType() == PR.ProviderTypeDocker {
|
||||
switch {
|
||||
case entry.IsDocker(r):
|
||||
if item.Category == "" {
|
||||
item.Category = "Docker"
|
||||
}
|
||||
item.SourceType = string(PR.ProviderTypeDocker)
|
||||
} else if p.GetType() == PR.ProviderTypeFile {
|
||||
item.SourceType = string(proxy.ProviderTypeDocker)
|
||||
case entry.UseLoadBalance(r):
|
||||
if item.Category == "" {
|
||||
item.Category = "Load-balanced"
|
||||
}
|
||||
item.SourceType = "loadbalancer"
|
||||
default:
|
||||
if item.Category == "" {
|
||||
item.Category = "Others"
|
||||
}
|
||||
item.SourceType = string(PR.ProviderTypeFile)
|
||||
item.SourceType = string(proxy.ProviderTypeFile)
|
||||
}
|
||||
|
||||
if item.URL == "" {
|
||||
@@ -90,39 +92,39 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
|
||||
item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port)
|
||||
}
|
||||
}
|
||||
item.AltURL = r.URL().String()
|
||||
item.AltURL = r.TargetURL().String()
|
||||
|
||||
hpCfg.Add(item)
|
||||
})
|
||||
return hpCfg
|
||||
}
|
||||
|
||||
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
|
||||
routes := make(map[string]U.SerializedObject)
|
||||
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
|
||||
if !r.Started() {
|
||||
return
|
||||
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
|
||||
routes := make(map[string]any)
|
||||
if len(typeFilter) == 0 || typeFilter[0] == "" {
|
||||
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
|
||||
}
|
||||
for _, t := range typeFilter {
|
||||
switch t {
|
||||
case route.RouteTypeReverseProxy:
|
||||
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
|
||||
routes[alias] = r
|
||||
})
|
||||
case route.RouteTypeStream:
|
||||
route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) {
|
||||
routes[alias] = r
|
||||
})
|
||||
}
|
||||
obj, err := U.Serialize(r)
|
||||
if err.HasError() {
|
||||
cfg.l.Error(err)
|
||||
return
|
||||
}
|
||||
obj["provider"] = p.GetName()
|
||||
obj["type"] = string(r.Type())
|
||||
obj["started"] = r.Started()
|
||||
obj["raw"] = r.Entry()
|
||||
routes[alias] = obj
|
||||
})
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func (cfg *Config) Statistics() map[string]any {
|
||||
func Statistics() map[string]any {
|
||||
nTotalStreams := 0
|
||||
nTotalRPs := 0
|
||||
providerStats := make(map[string]PR.ProviderStats)
|
||||
providerStats := make(map[string]proxy.ProviderStats)
|
||||
|
||||
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
|
||||
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
|
||||
providerStats[name] = p.Statistics()
|
||||
})
|
||||
|
||||
@@ -138,9 +140,9 @@ func (cfg *Config) Statistics() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) FindRoute(alias string) R.Route {
|
||||
return F.MapFind(cfg.proxyProviders,
|
||||
func(p *PR.Provider) (R.Route, bool) {
|
||||
func FindRoute(alias string) *route.Route {
|
||||
return F.MapFind(instance.providers,
|
||||
func(p *proxy.Provider) (*route.Route, bool) {
|
||||
if route, ok := p.GetRoute(alias); ok {
|
||||
return route, true
|
||||
}
|
||||
|
||||
13
internal/config/types/autocert_config.go
Normal file
13
internal/config/types/autocert_config.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package types
|
||||
|
||||
type (
|
||||
AutoCertConfig struct {
|
||||
Email string `json:"email,omitempty" yaml:"email"`
|
||||
Domains []string `json:"domains,omitempty" yaml:",flow"`
|
||||
CertPath string `json:"cert_path,omitempty" yaml:"cert_path"`
|
||||
KeyPath string `json:"key_path,omitempty" yaml:"key_path"`
|
||||
Provider string `json:"provider,omitempty" yaml:"provider"`
|
||||
Options AutocertProviderOpt `json:"options,omitempty" yaml:",flow"`
|
||||
}
|
||||
AutocertProviderOpt map[string]any
|
||||
)
|
||||
25
internal/config/types/config.go
Normal file
25
internal/config/types/config.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
Providers Providers `json:"providers" yaml:",flow"`
|
||||
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
|
||||
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
|
||||
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
|
||||
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
|
||||
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
|
||||
}
|
||||
Providers struct {
|
||||
Files []string `json:"include" yaml:"include"`
|
||||
Docker map[string]string `json:"docker" yaml:"docker"`
|
||||
Notification NotificationConfigMap `json:"notification" yaml:"notification"`
|
||||
}
|
||||
)
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Providers: Providers{},
|
||||
TimeoutShutdown: 3,
|
||||
RedirectToHTTPS: false,
|
||||
}
|
||||
}
|
||||
5
internal/config/types/notif_config.go
Normal file
5
internal/config/types/notif_config.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package types
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/notif"
|
||||
|
||||
type NotificationConfigMap map[string]notif.ProviderConfig
|
||||
@@ -1,65 +1,61 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/docker/cli/cli/connhelper"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
key string
|
||||
refCount *atomic.Int32
|
||||
*client.Client
|
||||
type (
|
||||
Client = *SharedClient
|
||||
SharedClient struct {
|
||||
*client.Client
|
||||
|
||||
l logrus.FieldLogger
|
||||
key string
|
||||
refCount *U.RefCount
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
|
||||
clientMapMu sync.Mutex
|
||||
|
||||
clientOptEnvHost = []client.Opt{
|
||||
client.WithHostFromEnv(),
|
||||
client.WithAPIVersionNegotiation(),
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
task.GlobalTask("close docker clients").OnFinished("", func() {
|
||||
clientMap.RangeAllParallel(func(_ string, c Client) {
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func ParseDockerHostname(host string) (string, E.NestedError) {
|
||||
switch host {
|
||||
case common.DockerHostFromEnv, "":
|
||||
return "localhost", nil
|
||||
}
|
||||
url, err := E.Check(client.ParseHostURL(host))
|
||||
if err != nil {
|
||||
return "", E.Invalid("host", host).With(err)
|
||||
}
|
||||
return url.Hostname(), nil
|
||||
func (c *SharedClient) Connected() bool {
|
||||
return c != nil && c.Client != nil
|
||||
}
|
||||
|
||||
func (c Client) DaemonHostname() string {
|
||||
// DaemonHost should always return a valid host
|
||||
hostname, _ := ParseDockerHostname(c.DaemonHost())
|
||||
return hostname
|
||||
}
|
||||
|
||||
func (c Client) Connected() bool {
|
||||
return c.Client != nil
|
||||
}
|
||||
|
||||
// if the client is still referenced, this is no-op
|
||||
func (c *Client) Close() error {
|
||||
if c.refCount.Add(-1) > 0 {
|
||||
return nil
|
||||
// if the client is still referenced, this is no-op.
|
||||
func (c *SharedClient) Close() {
|
||||
if c.Connected() {
|
||||
c.refCount.Sub()
|
||||
}
|
||||
|
||||
clientMap.Delete(c.key)
|
||||
|
||||
client := c.Client
|
||||
c.Client = nil
|
||||
|
||||
c.l.Debugf("client closed")
|
||||
|
||||
if client != nil {
|
||||
return client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectClient creates a new Docker client connection to the specified host.
|
||||
@@ -72,13 +68,13 @@ func (c *Client) Close() error {
|
||||
// Returns:
|
||||
// - Client: the Docker client connection.
|
||||
// - error: an error if the connection failed.
|
||||
func ConnectClient(host string) (Client, E.NestedError) {
|
||||
func ConnectClient(host string) (Client, error) {
|
||||
clientMapMu.Lock()
|
||||
defer clientMapMu.Unlock()
|
||||
|
||||
// check if client exists
|
||||
if client, ok := clientMap.Load(host); ok {
|
||||
client.refCount.Add(1)
|
||||
client.refCount.Add()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -86,12 +82,14 @@ func ConnectClient(host string) (Client, E.NestedError) {
|
||||
var opt []client.Opt
|
||||
|
||||
switch host {
|
||||
case "":
|
||||
return nil, errors.New("empty docker host")
|
||||
case common.DockerHostFromEnv:
|
||||
opt = clientOptEnvHost
|
||||
default:
|
||||
helper, err := E.Check(connhelper.GetConnectionHelper(host))
|
||||
if err.HasError() {
|
||||
return Client{}, E.UnexpectedError(err.Error())
|
||||
helper, err := connhelper.GetConnectionHelper(host)
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("failed to get connection helper")
|
||||
}
|
||||
if helper != nil {
|
||||
httpClient := &http.Client{
|
||||
@@ -113,41 +111,29 @@ func ConnectClient(host string) (Client, E.NestedError) {
|
||||
}
|
||||
}
|
||||
|
||||
client, err := E.Check(client.NewClientWithOpts(opt...))
|
||||
|
||||
if err.HasError() {
|
||||
return Client{}, err
|
||||
client, err := client.NewClientWithOpts(opt...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := Client{
|
||||
c := &SharedClient{
|
||||
Client: client,
|
||||
key: host,
|
||||
refCount: &atomic.Int32{},
|
||||
l: logger.WithField("docker_client", client.DaemonHost()),
|
||||
refCount: U.NewRefCounter(),
|
||||
l: logger.With().Str("address", client.DaemonHost()).Logger(),
|
||||
}
|
||||
c.refCount.Add(1)
|
||||
c.l.Debugf("client connected")
|
||||
c.l.Trace().Msg("client connected")
|
||||
|
||||
clientMap.Store(host, c)
|
||||
|
||||
go func() {
|
||||
<-c.refCount.Zero()
|
||||
clientMap.Delete(c.key)
|
||||
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
c.l.Trace().Msg("client closed")
|
||||
}
|
||||
}()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CloseAllClients() {
|
||||
clientMap.RangeAll(func(_ string, c Client) {
|
||||
c.Client.Close()
|
||||
})
|
||||
clientMap.Clear()
|
||||
logger.Debug("closed all clients")
|
||||
}
|
||||
|
||||
var (
|
||||
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
|
||||
clientMapMu sync.Mutex
|
||||
|
||||
clientOptEnvHost = []client.Opt{
|
||||
client.WithHostFromEnv(),
|
||||
client.WithAPIVersionNegotiation(),
|
||||
}
|
||||
|
||||
logger = logrus.WithField("module", "docker")
|
||||
)
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type ClientInfo struct {
|
||||
Client Client
|
||||
Containers []types.Container
|
||||
}
|
||||
|
||||
var listOptions = container.ListOptions{
|
||||
// Filters: filters.NewArgs(
|
||||
// filters.Arg("health", "healthy"),
|
||||
// filters.Arg("health", "none"),
|
||||
// filters.Arg("health", "starting"),
|
||||
// ),
|
||||
All: true,
|
||||
}
|
||||
|
||||
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
|
||||
dockerClient, err := ConnectClient(clientHost)
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("connect to docker", err)
|
||||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var containers []types.Container
|
||||
if getContainer {
|
||||
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("list containers", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &ClientInfo{
|
||||
Client: dockerClient,
|
||||
Containers: containers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func IsErrConnectionFailed(err error) bool {
|
||||
return client.IsErrConnectionFailed(err)
|
||||
}
|
||||
@@ -1,143 +1,144 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Container struct {
|
||||
*types.Container
|
||||
*ProxyProperties
|
||||
}
|
||||
type (
|
||||
PortMapping = map[string]types.Port
|
||||
Container struct {
|
||||
_ U.NoCopy
|
||||
|
||||
func FromDocker(c *types.Container, dockerHost string) (res Container) {
|
||||
res.Container = c
|
||||
isExplicit := c.Labels[LabelAliases] != ""
|
||||
res.ProxyProperties = &ProxyProperties{
|
||||
DockerHost: dockerHost,
|
||||
ContainerName: res.getName(),
|
||||
ContainerID: c.ID,
|
||||
ImageName: res.getImageName(),
|
||||
PublicPortMapping: res.getPublicPortMapping(),
|
||||
PrivatePortMapping: res.getPrivatePortMapping(),
|
||||
NetworkMode: c.HostConfig.NetworkMode,
|
||||
Aliases: res.getAliases(),
|
||||
IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)),
|
||||
IsExplicit: isExplicit,
|
||||
IsDatabase: res.isDatabase(),
|
||||
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
|
||||
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
|
||||
StopMethod: res.getDeleteLabel(LabelStopMethod),
|
||||
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
|
||||
StopSignal: res.getDeleteLabel(LabelStopSignal),
|
||||
Running: c.Status == "running" || c.State == "running",
|
||||
DockerHost string `json:"docker_host" yaml:"-"`
|
||||
ContainerName string `json:"container_name" yaml:"-"`
|
||||
ContainerID string `json:"container_id" yaml:"-"`
|
||||
ImageName string `json:"image_name" yaml:"-"`
|
||||
|
||||
Labels map[string]string `json:"labels" yaml:"-"`
|
||||
|
||||
PublicPortMapping PortMapping `json:"public_ports" yaml:"-"` // non-zero publicPort:types.Port
|
||||
PrivatePortMapping PortMapping `json:"private_ports" yaml:"-"` // privatePort:types.Port
|
||||
PublicIP string `json:"public_ip" yaml:"-"`
|
||||
PrivateIP string `json:"private_ip" yaml:"-"`
|
||||
NetworkMode string `json:"network_mode" yaml:"-"`
|
||||
|
||||
Aliases []string `json:"aliases" yaml:"-"`
|
||||
IsExcluded bool `json:"is_excluded" yaml:"-"`
|
||||
IsExplicit bool `json:"is_explicit" yaml:"-"`
|
||||
IsDatabase bool `json:"is_database" yaml:"-"`
|
||||
IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"`
|
||||
WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"`
|
||||
StopMethod string `json:"stop_method,omitempty" yaml:"-"`
|
||||
StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only
|
||||
StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only
|
||||
Running bool `json:"running" yaml:"-"`
|
||||
}
|
||||
)
|
||||
|
||||
var DummyContainer = new(Container)
|
||||
|
||||
func FromDocker(c *types.Container, dockerHost string) (res *Container) {
|
||||
isExplicit := c.Labels[LabelAliases] != ""
|
||||
helper := containerHelper{c}
|
||||
res = &Container{
|
||||
DockerHost: dockerHost,
|
||||
ContainerName: helper.getName(),
|
||||
ContainerID: c.ID,
|
||||
ImageName: helper.getImageName(),
|
||||
|
||||
Labels: c.Labels,
|
||||
|
||||
PublicPortMapping: helper.getPublicPortMapping(),
|
||||
PrivatePortMapping: helper.getPrivatePortMapping(),
|
||||
NetworkMode: c.HostConfig.NetworkMode,
|
||||
|
||||
Aliases: helper.getAliases(),
|
||||
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
|
||||
IsExplicit: isExplicit,
|
||||
IsDatabase: helper.isDatabase(),
|
||||
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
|
||||
WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout),
|
||||
StopMethod: helper.getDeleteLabel(LabelStopMethod),
|
||||
StopTimeout: helper.getDeleteLabel(LabelStopTimeout),
|
||||
StopSignal: helper.getDeleteLabel(LabelStopSignal),
|
||||
Running: c.Status == "running" || c.State == "running",
|
||||
}
|
||||
res.setPrivateIP(helper)
|
||||
res.setPublicIP()
|
||||
return
|
||||
}
|
||||
|
||||
func FromJson(json types.ContainerJSON, dockerHost string) Container {
|
||||
func FromJSON(json types.ContainerJSON, dockerHost string) *Container {
|
||||
ports := make([]types.Port, 0)
|
||||
for k, bindings := range json.NetworkSettings.Ports {
|
||||
privPortStr, proto := k.Port(), k.Proto()
|
||||
privPort, _ := strconv.ParseUint(privPortStr, 10, 16)
|
||||
ports = append(ports, types.Port{
|
||||
PrivatePort: uint16(privPort),
|
||||
Type: proto,
|
||||
})
|
||||
for _, v := range bindings {
|
||||
pubPort, _ := strconv.ParseUint(v.HostPort, 10, 16)
|
||||
privPort, _ := strconv.ParseUint(k.Port(), 10, 16)
|
||||
ports = append(ports, types.Port{
|
||||
IP: v.HostIP,
|
||||
PublicPort: uint16(pubPort),
|
||||
PrivatePort: uint16(privPort),
|
||||
Type: proto,
|
||||
})
|
||||
}
|
||||
}
|
||||
cont := FromDocker(&types.Container{
|
||||
ID: json.ID,
|
||||
Names: []string{json.Name},
|
||||
Names: []string{strings.TrimPrefix(json.Name, "/")},
|
||||
Image: json.Image,
|
||||
Ports: ports,
|
||||
Labels: json.Config.Labels,
|
||||
State: json.State.Status,
|
||||
Status: json.State.Status,
|
||||
Mounts: json.Mounts,
|
||||
NetworkSettings: &types.SummaryNetworkSettings{
|
||||
Networks: json.NetworkSettings.Networks,
|
||||
},
|
||||
}, dockerHost)
|
||||
cont.NetworkMode = string(json.HostConfig.NetworkMode)
|
||||
return cont
|
||||
}
|
||||
|
||||
func (c Container) getDeleteLabel(label string) string {
|
||||
if l, ok := c.Labels[label]; ok {
|
||||
delete(c.Labels, label)
|
||||
return l
|
||||
func (c *Container) setPublicIP() {
|
||||
if !c.Running {
|
||||
return
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c Container) getAliases() []string {
|
||||
if l := c.getDeleteLabel(LabelAliases); l != "" {
|
||||
return U.CommaSeperatedList(l)
|
||||
} else {
|
||||
return []string{c.getName()}
|
||||
if strings.HasPrefix(c.DockerHost, "unix://") {
|
||||
c.PublicIP = "127.0.0.1"
|
||||
return
|
||||
}
|
||||
url, err := url.Parse(c.DockerHost)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
||||
c.PublicIP = "127.0.0.1"
|
||||
return
|
||||
}
|
||||
c.PublicIP = url.Hostname()
|
||||
}
|
||||
|
||||
func (c Container) getName() string {
|
||||
return strings.TrimPrefix(c.Names[0], "/")
|
||||
}
|
||||
|
||||
func (c Container) getImageName() string {
|
||||
colonSep := strings.Split(c.Image, ":")
|
||||
slashSep := strings.Split(colonSep[0], "/")
|
||||
return slashSep[len(slashSep)-1]
|
||||
}
|
||||
|
||||
func (c Container) getPublicPortMapping() PortMapping {
|
||||
res := make(PortMapping)
|
||||
for _, v := range c.Ports {
|
||||
if v.PublicPort == 0 {
|
||||
func (c *Container) setPrivateIP(helper containerHelper) {
|
||||
if !strings.HasPrefix(c.DockerHost, "unix://") {
|
||||
return
|
||||
}
|
||||
if helper.NetworkSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, v := range helper.NetworkSettings.Networks {
|
||||
if v.IPAddress == "" {
|
||||
continue
|
||||
}
|
||||
res[fmt.Sprint(v.PublicPort)] = v
|
||||
c.PrivateIP = v.IPAddress
|
||||
return
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (c Container) getPrivatePortMapping() PortMapping {
|
||||
res := make(PortMapping)
|
||||
for _, v := range c.Ports {
|
||||
res[fmt.Sprint(v.PrivatePort)] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
var databaseMPs = map[string]struct{}{
|
||||
"/var/lib/postgresql/data": {},
|
||||
"/var/lib/mysql": {},
|
||||
"/var/lib/mongodb": {},
|
||||
"/var/lib/mariadb": {},
|
||||
"/var/lib/memcached": {},
|
||||
"/var/lib/rabbitmq": {},
|
||||
}
|
||||
|
||||
var databasePrivPorts = map[uint16]struct{}{
|
||||
5432: {}, // postgres
|
||||
3306: {}, // mysql, mariadb
|
||||
6379: {}, // redis
|
||||
11211: {}, // memcached
|
||||
27017: {}, // mongodb
|
||||
}
|
||||
|
||||
func (c Container) isDatabase() bool {
|
||||
for _, m := range c.Container.Mounts {
|
||||
if _, ok := databaseMPs[m.Destination]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range c.Ports {
|
||||
if _, ok := databasePrivPorts[v.PrivatePort]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
90
internal/docker/container_helper.go
Normal file
90
internal/docker/container_helper.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type containerHelper struct {
|
||||
*types.Container
|
||||
}
|
||||
|
||||
// getDeleteLabel gets the value of a label and then deletes it from the container.
|
||||
// If the label does not exist, an empty string is returned.
|
||||
func (c containerHelper) getDeleteLabel(label string) string {
|
||||
if l, ok := c.Labels[label]; ok {
|
||||
delete(c.Labels, label)
|
||||
return l
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c containerHelper) getAliases() []string {
|
||||
if l := c.getDeleteLabel(LabelAliases); l != "" {
|
||||
return strutils.CommaSeperatedList(l)
|
||||
}
|
||||
return []string{c.getName()}
|
||||
}
|
||||
|
||||
func (c containerHelper) getName() string {
|
||||
return strings.TrimPrefix(c.Names[0], "/")
|
||||
}
|
||||
|
||||
func (c containerHelper) getImageName() string {
|
||||
colonSep := strings.Split(c.Image, ":")
|
||||
slashSep := strings.Split(colonSep[0], "/")
|
||||
return slashSep[len(slashSep)-1]
|
||||
}
|
||||
|
||||
func (c containerHelper) getPublicPortMapping() PortMapping {
|
||||
res := make(PortMapping)
|
||||
for _, v := range c.Ports {
|
||||
if v.PublicPort == 0 {
|
||||
continue
|
||||
}
|
||||
res[strutils.PortString(v.PublicPort)] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (c containerHelper) getPrivatePortMapping() PortMapping {
|
||||
res := make(PortMapping)
|
||||
for _, v := range c.Ports {
|
||||
res[strutils.PortString(v.PrivatePort)] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
var databaseMPs = map[string]struct{}{
|
||||
"/var/lib/postgresql/data": {},
|
||||
"/var/lib/mysql": {},
|
||||
"/var/lib/mongodb": {},
|
||||
"/var/lib/mariadb": {},
|
||||
"/var/lib/memcached": {},
|
||||
"/var/lib/rabbitmq": {},
|
||||
}
|
||||
|
||||
var databasePrivPorts = map[uint16]struct{}{
|
||||
5432: {}, // postgres
|
||||
3306: {}, // mysql, mariadb
|
||||
6379: {}, // redis
|
||||
11211: {}, // memcached
|
||||
27017: {}, // mongodb
|
||||
}
|
||||
|
||||
func (c containerHelper) isDatabase() bool {
|
||||
for _, m := range c.Mounts {
|
||||
if _, ok := databaseMPs[m.Destination]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range c.Ports {
|
||||
if _, ok := databasePrivPorts[v.PrivatePort]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -3,9 +3,10 @@ package idlewatcher
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
type templateData struct {
|
||||
@@ -18,18 +19,15 @@ type templateData struct {
|
||||
var loadingPage []byte
|
||||
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
|
||||
|
||||
const headerCheckRedirect = "X-GoProxy-Check-Redirect"
|
||||
|
||||
func (w *watcher) makeRespBody(format string, args ...any) []byte {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
func (w *Watcher) makeLoadingPageBody() []byte {
|
||||
msg := w.ContainerName + " is starting..."
|
||||
|
||||
data := new(templateData)
|
||||
data.CheckRedirectHeader = headerCheckRedirect
|
||||
data.CheckRedirectHeader = common.HeaderCheckRedirect
|
||||
data.Title = w.ContainerName
|
||||
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
|
||||
data.Message = strings.ReplaceAll(data.Message, " ", " ")
|
||||
data.Message = strings.ReplaceAll(msg, " ", " ")
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
|
||||
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect)))
|
||||
err := loadingPageTmpl.Execute(buf, data)
|
||||
if err != nil { // should never happen in production
|
||||
panic(err)
|
||||
103
internal/docker/idlewatcher/types/config.go
Normal file
103
internal/docker/idlewatcher/types/config.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
|
||||
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
|
||||
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
|
||||
StopMethod StopMethod `json:"stop_method,omitempty"`
|
||||
StopSignal Signal `json:"stop_signal,omitempty"`
|
||||
|
||||
DockerHost string `json:"docker_host,omitempty"`
|
||||
ContainerName string `json:"container_name,omitempty"`
|
||||
ContainerID string `json:"container_id,omitempty"`
|
||||
ContainerRunning bool `json:"container_running,omitempty"`
|
||||
}
|
||||
StopMethod string
|
||||
Signal string
|
||||
)
|
||||
|
||||
const (
|
||||
StopMethodPause StopMethod = "pause"
|
||||
StopMethodStop StopMethod = "stop"
|
||||
StopMethodKill StopMethod = "kill"
|
||||
)
|
||||
|
||||
func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
|
||||
if cont == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cont.IdleTimeout == "" {
|
||||
return &Config{
|
||||
DockerHost: cont.DockerHost,
|
||||
ContainerName: cont.ContainerName,
|
||||
ContainerID: cont.ContainerID,
|
||||
ContainerRunning: cont.Running,
|
||||
}, nil
|
||||
}
|
||||
|
||||
errs := E.NewBuilder("invalid idlewatcher config")
|
||||
|
||||
idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
|
||||
wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
|
||||
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
|
||||
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
|
||||
signal := E.Collect(errs, validateSignal, cont.StopSignal)
|
||||
|
||||
if errs.HasError() {
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
return &Config{
|
||||
IdleTimeout: idleTimeout,
|
||||
WakeTimeout: wakeTimeout,
|
||||
StopTimeout: int(stopTimeout.Seconds()),
|
||||
StopMethod: stopMethod,
|
||||
StopSignal: signal,
|
||||
|
||||
DockerHost: cont.DockerHost,
|
||||
ContainerName: cont.ContainerName,
|
||||
ContainerID: cont.ContainerID,
|
||||
ContainerRunning: cont.Running,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateDurationPostitive(value string) (time.Duration, error) {
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, errors.New("duration must be positive")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func validateSignal(s string) (Signal, error) {
|
||||
switch s {
|
||||
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
|
||||
"INT", "TERM", "HUP", "QUIT":
|
||||
return Signal(s), nil
|
||||
}
|
||||
|
||||
return "", errors.New("invalid signal " + s)
|
||||
}
|
||||
|
||||
func validateStopMethod(s string) (StopMethod, error) {
|
||||
sm := StopMethod(s)
|
||||
switch sm {
|
||||
case StopMethodPause, StopMethodStop, StopMethodKill:
|
||||
return sm, nil
|
||||
default:
|
||||
return "", errors.New("invalid stop method " + s)
|
||||
}
|
||||
}
|
||||
14
internal/docker/idlewatcher/types/waker.go
Normal file
14
internal/docker/idlewatcher/types/waker.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Waker interface {
|
||||
health.HealthMonitor
|
||||
http.Handler
|
||||
net.Stream
|
||||
}
|
||||
@@ -1,101 +1,134 @@
|
||||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Waker struct {
|
||||
*watcher
|
||||
type waker struct {
|
||||
_ U.NoCopy
|
||||
|
||||
client *http.Client
|
||||
rp *gphttp.ReverseProxy
|
||||
stream net.Stream
|
||||
hc health.HealthChecker
|
||||
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
||||
tr := &http.Transport{}
|
||||
if w.NoTLSVerify {
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
const (
|
||||
idleWakerCheckInterval = 100 * time.Millisecond
|
||||
idleWakerCheckTimeout = time.Second
|
||||
)
|
||||
|
||||
// TODO: support stream
|
||||
|
||||
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
|
||||
hcCfg := entry.HealthCheckConfig()
|
||||
hcCfg.Timeout = idleWakerCheckTimeout
|
||||
|
||||
waker := &waker{
|
||||
rp: rp,
|
||||
stream: stream,
|
||||
}
|
||||
return &Waker{
|
||||
watcher: w,
|
||||
client: &http.Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Transport: tr,
|
||||
},
|
||||
rp: rp,
|
||||
|
||||
watcher, err := registerWatcher(providerSubTask, entry, waker)
|
||||
if err != nil {
|
||||
return nil, E.Errorf("register watcher: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case rp != nil:
|
||||
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
|
||||
case stream != nil:
|
||||
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
|
||||
default:
|
||||
panic("both nil")
|
||||
}
|
||||
return watcher, nil
|
||||
}
|
||||
|
||||
// lifetime should follow route provider.
|
||||
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
|
||||
return newWaker(providerSubTask, entry, rp, nil)
|
||||
}
|
||||
|
||||
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) {
|
||||
return newWaker(providerSubTask, entry, nil, stream)
|
||||
}
|
||||
|
||||
// Start implements health.HealthMonitor.
|
||||
func (w *Watcher) Start(routeSubTask task.Task) E.Error {
|
||||
routeSubTask.Finish("ignored")
|
||||
w.task.OnCancel("stop route", func() {
|
||||
routeSubTask.Parent().Finish(w.task.FinishCause())
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finish implements health.HealthMonitor.
|
||||
func (w *Watcher) Finish(reason any) {
|
||||
if w.stream != nil {
|
||||
w.stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
w.wake(w.rp.ServeHTTP, rw, r)
|
||||
// Name implements health.HealthMonitor.
|
||||
func (w *Watcher) Name() string {
|
||||
return w.String()
|
||||
}
|
||||
|
||||
func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Request) {
|
||||
// pass through if container is ready
|
||||
// String implements health.HealthMonitor.
|
||||
func (w *Watcher) String() string {
|
||||
return w.ContainerName
|
||||
}
|
||||
|
||||
// Uptime implements health.HealthMonitor.
|
||||
func (w *Watcher) Uptime() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Status implements health.HealthMonitor.
|
||||
func (w *Watcher) Status() health.Status {
|
||||
if !w.ContainerRunning {
|
||||
return health.StatusNapping
|
||||
}
|
||||
|
||||
if w.ready.Load() {
|
||||
next(rw, r)
|
||||
return
|
||||
return health.StatusHealthy
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
|
||||
defer cancel()
|
||||
|
||||
if r.Header.Get(headerCheckRedirect) == "" {
|
||||
// Send a loading response to the client
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Write(w.makeRespBody("%s waking up...", w.ContainerName))
|
||||
return
|
||||
}
|
||||
|
||||
// wake the container and reset idle timer
|
||||
// also wait for another wake request
|
||||
w.wakeCh <- struct{}{}
|
||||
|
||||
if <-w.wakeDone != nil {
|
||||
http.Error(rw, "Error sending wake request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// maybe another request came in while we were waiting for the wake
|
||||
if w.ready.Load() {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
wakeReq, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodHead,
|
||||
w.URL.String(),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
w.l.Errorf("new request err to %s: %s", r.URL, err)
|
||||
http.Error(rw, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// we don't care about the response
|
||||
_, err = w.client.Do(wakeReq)
|
||||
if err == nil {
|
||||
w.ready.Store(true)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
healthy, _, err := w.hc.CheckHealth()
|
||||
switch {
|
||||
case err != nil:
|
||||
w.ready.Store(false)
|
||||
return health.StatusError
|
||||
case healthy:
|
||||
w.ready.Store(true)
|
||||
return health.StatusHealthy
|
||||
default:
|
||||
return health.StatusStarting
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (w *Watcher) MarshalJSON() ([]byte, error) {
|
||||
var url net.URL
|
||||
if w.hc.URL().Port() != "0" {
|
||||
url = w.hc.URL()
|
||||
}
|
||||
return (&health.JSONRepresentation{
|
||||
Name: w.Name(),
|
||||
Status: w.Status(),
|
||||
Config: w.hc.Config(),
|
||||
URL: url,
|
||||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
108
internal/docker/idlewatcher/waker_http.go
Normal file
108
internal/docker/idlewatcher/waker_http.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
shouldNext := w.wakeFromHTTP(rw, r)
|
||||
if !shouldNext {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
default:
|
||||
w.rp.ServeHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
|
||||
w.resetIdleTimer()
|
||||
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return true
|
||||
}
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
accept := gphttp.GetAccept(r.Header)
|
||||
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
|
||||
|
||||
isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != ""
|
||||
if !isCheckRedirect && acceptHTML {
|
||||
// Send a loading response to the client
|
||||
body := w.makeLoadingPageBody()
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
rw.Header().Add("Cache-Control", "no-cache")
|
||||
rw.Header().Add("Cache-Control", "no-store")
|
||||
rw.Header().Add("Cache-Control", "must-revalidate")
|
||||
rw.Header().Add("Connection", "close")
|
||||
if _, err := rw.Write(body); err != nil {
|
||||
w.Err(err).Msg("error writing http response")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
checkCanceled := func() (canceled bool) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled")
|
||||
return true
|
||||
case <-w.task.Context().Done():
|
||||
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if checkCanceled() {
|
||||
return false
|
||||
}
|
||||
|
||||
w.WakeTrace().Msg("signal received")
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.WakeError(err)
|
||||
http.Error(rw, "Error waking container", http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
for {
|
||||
if checkCanceled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
if isCheckRedirect {
|
||||
w.Debug().Msgf("redirecting to %s ...", w.hc.URL())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return false
|
||||
}
|
||||
w.Debug().Msgf("passing through to %s ...", w.hc.URL())
|
||||
return true
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(idleWakerCheckInterval)
|
||||
}
|
||||
}
|
||||
90
internal/docker/idlewatcher/waker_stream.go
Normal file
90
internal/docker/idlewatcher/waker_stream.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Addr() net.Addr {
|
||||
return w.stream.Addr()
|
||||
}
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Setup() error {
|
||||
return w.stream.Setup()
|
||||
}
|
||||
|
||||
// Accept implements types.Stream.
|
||||
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
|
||||
conn, err = w.stream.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if wakeErr := w.wakeFromStream(); wakeErr != nil {
|
||||
w.WakeError(wakeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle implements types.Stream.
|
||||
func (w *Watcher) Handle(conn types.StreamConn) error {
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.stream.Handle(conn)
|
||||
}
|
||||
|
||||
// Close implements types.Stream.
|
||||
func (w *Watcher) Close() error {
|
||||
return w.stream.Close()
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromStream() error {
|
||||
w.resetIdleTimer()
|
||||
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.WakeDebug().Msg("wake signal received")
|
||||
wakeErr := w.wakeIfStopped()
|
||||
if wakeErr != nil {
|
||||
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
|
||||
w.WakeError(wakeErr)
|
||||
return wakeErr
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
cause := w.task.FinishCause()
|
||||
w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
|
||||
return cause
|
||||
case <-ctx.Done():
|
||||
cause := context.Cause(ctx)
|
||||
w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
|
||||
return cause
|
||||
default:
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(idleWakerCheckInterval)
|
||||
}
|
||||
}
|
||||
@@ -2,269 +2,292 @@ package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
type (
|
||||
watcher struct {
|
||||
*P.ReverseProxyEntry
|
||||
Watcher struct {
|
||||
_ U.NoCopy
|
||||
|
||||
client D.Client
|
||||
zerolog.Logger
|
||||
|
||||
ready atomic.Bool // whether the site is ready to accept connection
|
||||
*idlewatcher.Config
|
||||
*waker
|
||||
|
||||
client D.Client
|
||||
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
|
||||
|
||||
wakeCh chan struct{}
|
||||
wakeDone chan E.NestedError
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
refCount *sync.WaitGroup
|
||||
|
||||
l logrus.FieldLogger
|
||||
ticker *time.Ticker
|
||||
task task.Task
|
||||
}
|
||||
|
||||
WakeDone <-chan error
|
||||
WakeFunc func() WakeDone
|
||||
StopCallback func() E.NestedError
|
||||
StopCallback func() error
|
||||
)
|
||||
|
||||
var (
|
||||
mainLoopCtx context.Context
|
||||
mainLoopCancel context.CancelFunc
|
||||
mainLoopWg sync.WaitGroup
|
||||
|
||||
watcherMap = make(map[string]*watcher)
|
||||
watcherMap = F.NewMapOf[string, *Watcher]()
|
||||
watcherMapMu sync.Mutex
|
||||
|
||||
newWatcherCh = make(chan *watcher)
|
||||
|
||||
logger = logrus.WithField("module", "idle_watcher")
|
||||
logger = logging.With().Str("module", "idle_watcher").Logger()
|
||||
)
|
||||
|
||||
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
|
||||
failure := E.Failure("idle_watcher register")
|
||||
const dockerReqTimeout = 3 * time.Second
|
||||
|
||||
if entry.IdleTimeout == 0 {
|
||||
return nil, failure.With(E.Invalid("idle_timeout", 0))
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) {
|
||||
cfg := entry.IdlewatcherConfig()
|
||||
|
||||
if cfg.IdleTimeout == 0 {
|
||||
panic("should not reach here")
|
||||
}
|
||||
|
||||
watcherMapMu.Lock()
|
||||
defer watcherMapMu.Unlock()
|
||||
|
||||
key := entry.ContainerID
|
||||
key := cfg.ContainerID
|
||||
|
||||
if w, ok := watcherMap[key]; ok {
|
||||
w.refCount.Add(1)
|
||||
w.ReverseProxyEntry = entry
|
||||
if w, ok := watcherMap.Load(key); ok {
|
||||
w.Config = cfg
|
||||
w.waker = waker
|
||||
w.resetIdleTimer()
|
||||
providerSubtask.Finish("used existing watcher")
|
||||
return w, nil
|
||||
}
|
||||
|
||||
client, err := D.ConnectClient(entry.DockerHost)
|
||||
if err.HasError() {
|
||||
return nil, failure.With(err)
|
||||
client, err := D.ConnectClient(cfg.DockerHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w := &watcher{
|
||||
ReverseProxyEntry: entry,
|
||||
client: client,
|
||||
refCount: &sync.WaitGroup{},
|
||||
wakeCh: make(chan struct{}),
|
||||
wakeDone: make(chan E.NestedError),
|
||||
l: logger.WithField("container", entry.ContainerName),
|
||||
w := &Watcher{
|
||||
Logger: logger.With().Str("name", cfg.ContainerName).Logger(),
|
||||
Config: cfg,
|
||||
waker: waker,
|
||||
client: client,
|
||||
task: providerSubtask,
|
||||
ticker: time.NewTicker(cfg.IdleTimeout),
|
||||
}
|
||||
w.refCount.Add(1)
|
||||
w.stopByMethod = w.getStopCallback()
|
||||
|
||||
watcherMap[key] = w
|
||||
watcherMap.Store(key, w)
|
||||
|
||||
go func() {
|
||||
newWatcherCh <- w
|
||||
cause := w.watchUntilDestroy()
|
||||
|
||||
watcherMap.Delete(w.ContainerID)
|
||||
|
||||
w.ticker.Stop()
|
||||
w.client.Close()
|
||||
w.task.Finish(cause)
|
||||
}()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *watcher) Unregister() {
|
||||
w.refCount.Add(-1)
|
||||
// WakeDebug logs a debug message related to waking the container.
|
||||
func (w *Watcher) WakeDebug() *zerolog.Event {
|
||||
return w.Debug().Str("action", "wake")
|
||||
}
|
||||
|
||||
func Start() {
|
||||
logger.Debug("started")
|
||||
defer logger.Debug("stopped")
|
||||
|
||||
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-mainLoopCtx.Done():
|
||||
return
|
||||
case w := <-newWatcherCh:
|
||||
w.l.Debug("registered")
|
||||
mainLoopWg.Add(1)
|
||||
go func() {
|
||||
w.watchUntilCancel()
|
||||
w.refCount.Wait() // wait for 0 ref count
|
||||
|
||||
w.client.Close()
|
||||
delete(watcherMap, w.ContainerID)
|
||||
w.l.Debug("unregistered")
|
||||
mainLoopWg.Done()
|
||||
}()
|
||||
}
|
||||
}
|
||||
func (w *Watcher) WakeTrace() *zerolog.Event {
|
||||
return w.Trace().Str("action", "wake")
|
||||
}
|
||||
|
||||
func Stop() {
|
||||
mainLoopCancel()
|
||||
mainLoopWg.Wait()
|
||||
func (w *Watcher) WakeError(err error) {
|
||||
w.Err(err).Str("action", "wake").Msg("error")
|
||||
}
|
||||
|
||||
func (w *watcher) containerStop() error {
|
||||
return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{
|
||||
func (w *Watcher) LogReason(action, reason string) {
|
||||
w.Info().Str("reason", reason).Msg(action)
|
||||
}
|
||||
|
||||
func (w *Watcher) containerStop(ctx context.Context) error {
|
||||
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
|
||||
Signal: string(w.StopSignal),
|
||||
Timeout: &w.StopTimeout})
|
||||
Timeout: &w.StopTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *watcher) containerPause() error {
|
||||
return w.client.ContainerPause(w.ctx, w.ContainerID)
|
||||
func (w *Watcher) containerPause(ctx context.Context) error {
|
||||
return w.client.ContainerPause(ctx, w.ContainerID)
|
||||
}
|
||||
|
||||
func (w *watcher) containerKill() error {
|
||||
return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal))
|
||||
func (w *Watcher) containerKill(ctx context.Context) error {
|
||||
return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal))
|
||||
}
|
||||
|
||||
func (w *watcher) containerUnpause() error {
|
||||
return w.client.ContainerUnpause(w.ctx, w.ContainerID)
|
||||
func (w *Watcher) containerUnpause(ctx context.Context) error {
|
||||
return w.client.ContainerUnpause(ctx, w.ContainerID)
|
||||
}
|
||||
|
||||
func (w *watcher) containerStart() error {
|
||||
return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{})
|
||||
func (w *Watcher) containerStart(ctx context.Context) error {
|
||||
return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{})
|
||||
}
|
||||
|
||||
func (w *watcher) containerStatus() (string, E.NestedError) {
|
||||
json, err := w.client.ContainerInspect(w.ctx, w.ContainerID)
|
||||
func (w *Watcher) containerStatus() (string, error) {
|
||||
if !w.client.Connected() {
|
||||
return "", errors.New("docker client not connected")
|
||||
}
|
||||
ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
|
||||
defer cancel()
|
||||
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
|
||||
if err != nil {
|
||||
return "", E.FailWith("inspect container", err)
|
||||
return "", err
|
||||
}
|
||||
return json.State.Status, nil
|
||||
}
|
||||
|
||||
func (w *watcher) wakeIfStopped() E.NestedError {
|
||||
if w.ready.Load() || w.ContainerRunning {
|
||||
func (w *Watcher) wakeIfStopped() error {
|
||||
if w.ContainerRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
status, err := w.containerStatus()
|
||||
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
|
||||
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout)
|
||||
defer cancel()
|
||||
|
||||
// !Hard coded here since theres no constants from Docker API
|
||||
switch status {
|
||||
case "exited", "dead":
|
||||
return E.From(w.containerStart())
|
||||
return w.containerStart(ctx)
|
||||
case "paused":
|
||||
return E.From(w.containerUnpause())
|
||||
return w.containerUnpause(ctx)
|
||||
case "running":
|
||||
return nil
|
||||
default:
|
||||
return E.Unexpected("container state", status)
|
||||
panic("should not reach here")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) getStopCallback() StopCallback {
|
||||
var cb func() error
|
||||
func (w *Watcher) getStopCallback() StopCallback {
|
||||
var cb func(context.Context) error
|
||||
switch w.StopMethod {
|
||||
case PT.StopMethodPause:
|
||||
case idlewatcher.StopMethodPause:
|
||||
cb = w.containerPause
|
||||
case PT.StopMethodStop:
|
||||
case idlewatcher.StopMethodStop:
|
||||
cb = w.containerStop
|
||||
case PT.StopMethodKill:
|
||||
case idlewatcher.StopMethodKill:
|
||||
cb = w.containerKill
|
||||
default:
|
||||
panic("should not reach here")
|
||||
}
|
||||
return func() E.NestedError {
|
||||
status, err := w.containerStatus()
|
||||
if err.HasError() {
|
||||
return err
|
||||
}
|
||||
if status != "running" {
|
||||
return nil
|
||||
}
|
||||
return E.From(cb())
|
||||
return func() error {
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second)
|
||||
defer cancel()
|
||||
return cb(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) watchUntilCancel() {
|
||||
defer close(w.wakeCh)
|
||||
func (w *Watcher) resetIdleTimer() {
|
||||
w.Trace().Msg("reset idle timer")
|
||||
w.ticker.Reset(w.IdleTimeout)
|
||||
}
|
||||
|
||||
w.ctx, w.cancel = context.WithCancel(context.Background())
|
||||
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
|
||||
Filters: W.NewDockerFilter(
|
||||
W.DockerFilterContainer,
|
||||
W.DockerrFilterContainer(w.ContainerID),
|
||||
W.DockerFilterStart,
|
||||
W.DockerFilterStop,
|
||||
W.DockerFilterDie,
|
||||
W.DockerFilterKill,
|
||||
W.DockerFilterPause,
|
||||
W.DockerFilterUnpause,
|
||||
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
|
||||
eventTask = w.task.Subtask("docker event watcher")
|
||||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
|
||||
Filters: watcher.NewDockerFilter(
|
||||
watcher.DockerFilterContainer,
|
||||
watcher.DockerFilterContainerNameID(w.ContainerID),
|
||||
watcher.DockerFilterStart,
|
||||
watcher.DockerFilterStop,
|
||||
watcher.DockerFilterDie,
|
||||
watcher.DockerFilterKill,
|
||||
watcher.DockerFilterDestroy,
|
||||
watcher.DockerFilterPause,
|
||||
watcher.DockerFilterUnpause,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(w.IdleTimeout)
|
||||
defer ticker.Stop()
|
||||
// watchUntilDestroy waits for the container to be created, started, or unpaused,
|
||||
// and then reset the idle timer.
|
||||
//
|
||||
// When the container is stopped, paused,
|
||||
// or killed, the idle timer is stopped and the ContainerRunning flag is set to false.
|
||||
//
|
||||
// When the idle timer fires, the container is stopped according to the
|
||||
// stop method.
|
||||
//
|
||||
// it exits only if the context is canceled, the container is destroyed,
|
||||
// errors occurred on docker client, or route provider died (mainly caused by config reload).
|
||||
func (w *Watcher) watchUntilDestroy() (returnCause error) {
|
||||
dockerWatcher := watcher.NewDockerWatcherWithClient(w.client)
|
||||
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
|
||||
defer eventTask.Finish("stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-mainLoopCtx.Done():
|
||||
w.cancel()
|
||||
case <-w.ctx.Done():
|
||||
w.l.Debug("stopped")
|
||||
return
|
||||
case <-w.task.Context().Done():
|
||||
return w.task.FinishCause()
|
||||
case err := <-dockerEventErrCh:
|
||||
if err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("docker watcher", err))
|
||||
if !err.Is(context.Canceled) {
|
||||
E.LogError("idlewatcher error", err, &w.Logger)
|
||||
}
|
||||
return err
|
||||
case e := <-dockerEventCh:
|
||||
switch {
|
||||
case e.Action == events.ActionContainerDestroy:
|
||||
w.ContainerRunning = false
|
||||
w.ready.Store(false)
|
||||
w.LogReason("watcher stopped", "container destroyed")
|
||||
return errors.New("container destroyed")
|
||||
// create / start / unpause
|
||||
case e.Action.IsContainerWake():
|
||||
ticker.Reset(w.IdleTimeout)
|
||||
w.l.Info(e)
|
||||
default: // stop / pause / kill
|
||||
ticker.Stop()
|
||||
w.ContainerRunning = true
|
||||
w.resetIdleTimer()
|
||||
w.Info().Msg("awaken")
|
||||
case e.Action.IsContainerSleep(): // stop / pause / kil
|
||||
w.ContainerRunning = false
|
||||
w.ready.Store(false)
|
||||
w.l.Info(e)
|
||||
w.ticker.Stop()
|
||||
default:
|
||||
w.Error().Msg("unexpected docker event: " + e.String())
|
||||
}
|
||||
case <-ticker.C:
|
||||
w.l.Debug("idle timeout")
|
||||
ticker.Stop()
|
||||
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
|
||||
// container name changed should also change the container id
|
||||
if w.ContainerName != e.ActorName {
|
||||
w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
|
||||
w.ContainerName = e.ActorName
|
||||
}
|
||||
case <-w.wakeCh:
|
||||
w.l.Debug("wake signal received")
|
||||
ticker.Reset(w.IdleTimeout)
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.l.Error(E.FailWith("wake", err))
|
||||
if w.ContainerID != e.ActorID {
|
||||
w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
|
||||
w.ContainerID = e.ActorID
|
||||
// recreate event stream
|
||||
eventTask.Finish("recreate event stream")
|
||||
eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher)
|
||||
}
|
||||
case <-w.ticker.C:
|
||||
w.ticker.Stop()
|
||||
if w.ContainerRunning {
|
||||
err := w.stopByMethod()
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
continue
|
||||
case err != nil:
|
||||
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
|
||||
default:
|
||||
w.LogReason("container stopped", "idle timeout")
|
||||
}
|
||||
}
|
||||
w.wakeDone <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,28 @@ package docker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
func Inspect(dockerHost string, containerID string) (*Container, error) {
|
||||
client, err := ConnectClient(dockerHost)
|
||||
defer client.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Inspect(containerID)
|
||||
}
|
||||
|
||||
func (c Client) Inspect(containerID string) (*Container, error) {
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
|
||||
defer cancel()
|
||||
|
||||
json, err := c.ContainerInspect(ctx, containerID)
|
||||
if err != nil {
|
||||
return Container{}, E.From(err)
|
||||
return nil, err
|
||||
}
|
||||
return FromJson(json, c.key), nil
|
||||
return FromJSON(json, c.key), nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ type (
|
||||
NestedLabelMap map[string]U.SerializedObject
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApplyToNil = E.New("label value is nil")
|
||||
ErrFieldNotExist = E.New("field does not exist")
|
||||
)
|
||||
|
||||
func (l *Label) String() string {
|
||||
if l.Attribute == "" {
|
||||
return l.Namespace + "." + l.Target
|
||||
@@ -39,22 +44,22 @@ func (l *Label) String() string {
|
||||
//
|
||||
// Returns:
|
||||
// - error: an error if the field does not exist.
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
||||
if obj == nil {
|
||||
return E.Invalid("nil object", l)
|
||||
return ErrApplyToNil.Subject(l.String())
|
||||
}
|
||||
switch nestedLabel := l.Value.(type) {
|
||||
case *Label:
|
||||
var field reflect.Value
|
||||
objType := reflect.TypeFor[T]()
|
||||
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
|
||||
for i := range reflect.TypeFor[T]().NumField() {
|
||||
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
|
||||
field = reflect.ValueOf(obj).Elem().Field(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !field.IsValid() {
|
||||
return E.NotExist("field", l.Attribute)
|
||||
return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String())
|
||||
}
|
||||
dst, ok := field.Interface().(NestedLabelMap)
|
||||
if !ok {
|
||||
@@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||
} else {
|
||||
field = field.Addr()
|
||||
}
|
||||
return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
|
||||
err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
|
||||
if err != nil {
|
||||
return err.Subject(l.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if dst == nil {
|
||||
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
|
||||
@@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
|
||||
return nil
|
||||
default:
|
||||
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||
err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||
if err != nil {
|
||||
return err.Subject(l.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||
func ParseLabel(label string, value string) *Label {
|
||||
parts := strings.Split(label, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return &Label{
|
||||
Namespace: label,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
l := &Label{
|
||||
@@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||
l.Attribute = parts[2]
|
||||
default:
|
||||
l.Attribute = parts[2]
|
||||
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
}
|
||||
nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||
l.Value = nestedLabel
|
||||
}
|
||||
|
||||
return l, nil
|
||||
return l
|
||||
}
|
||||
|
||||
@@ -8,17 +8,20 @@ import (
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
const (
|
||||
mName = "middleware1"
|
||||
mAttr = "prop1"
|
||||
v = "value1"
|
||||
)
|
||||
|
||||
func makeLabel(ns, name, attr string) string {
|
||||
return fmt.Sprintf("%s.%s.%s", ns, name, attr)
|
||||
}
|
||||
|
||||
func TestNestedLabel(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
sGot := ExpectType[*Label](t, pl.Value)
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
sGot := ExpectType[*Label](t, lbl.Value)
|
||||
ExpectFalse(t, sGot == nil)
|
||||
ExpectEqual(t, sGot.Namespace, mName)
|
||||
ExpectEqual(t, sGot.Attribute, mAttr)
|
||||
@@ -28,13 +31,9 @@ func TestApplyNestedLabel(t *testing.T) {
|
||||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
@@ -42,10 +41,6 @@ func TestApplyNestedLabel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyNestedLabelExisting(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
|
||||
checkAttr := "prop2"
|
||||
checkV := "value2"
|
||||
entry := new(struct {
|
||||
@@ -55,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) {
|
||||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
entry.Middlewares[mName][checkAttr] = checkV
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
@@ -71,19 +65,15 @@ func TestApplyNestedLabelExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyNestedLabelNoAttr(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
v := "value1"
|
||||
|
||||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
entry.Middlewares = make(NestedLabelMap)
|
||||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
|
||||
err := ApplyLabel(entry, lbl)
|
||||
ExpectNoError(t, err)
|
||||
_, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
}
|
||||
|
||||
44
internal/docker/list_containers.go
Normal file
44
internal/docker/list_containers.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
var listOptions = container.ListOptions{
|
||||
// created|restarting|running|removing|paused|exited|dead
|
||||
// Filters: filters.NewArgs(
|
||||
// filters.Arg("status", "created"),
|
||||
// filters.Arg("status", "restarting"),
|
||||
// filters.Arg("status", "running"),
|
||||
// filters.Arg("status", "paused"),
|
||||
// filters.Arg("status", "exited"),
|
||||
// ),
|
||||
All: true,
|
||||
}
|
||||
|
||||
func ListContainers(clientHost string) ([]types.Container, error) {
|
||||
dockerClient, err := ConnectClient(clientHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
|
||||
defer cancel()
|
||||
|
||||
containers, err := dockerClient.ContainerList(ctx, listOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return containers, nil
|
||||
}
|
||||
|
||||
func IsErrConnectionFailed(err error) bool {
|
||||
return client.IsErrConnectionFailed(err)
|
||||
}
|
||||
7
internal/docker/logger.go
Normal file
7
internal/docker/logger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
var logger = logging.With().Str("module", "docker").Logger()
|
||||
@@ -1,25 +0,0 @@
|
||||
package docker
|
||||
|
||||
import "github.com/docker/docker/api/types"
|
||||
|
||||
type PortMapping = map[string]types.Port
|
||||
type ProxyProperties struct {
|
||||
DockerHost string `yaml:"-" json:"docker_host"`
|
||||
ContainerName string `yaml:"-" json:"container_name"`
|
||||
ContainerID string `yaml:"-" json:"container_id"`
|
||||
ImageName string `yaml:"-" json:"image_name"`
|
||||
PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port
|
||||
PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port
|
||||
NetworkMode string `yaml:"-" json:"network_mode"`
|
||||
|
||||
Aliases []string `yaml:"-" json:"aliases"`
|
||||
IsExcluded bool `yaml:"-" json:"is_excluded"`
|
||||
IsExplicit bool `yaml:"-" json:"is_explicit"`
|
||||
IsDatabase bool `yaml:"-" json:"is_database"`
|
||||
IdleTimeout string `yaml:"-" json:"idle_timeout"`
|
||||
WakeTimeout string `yaml:"-" json:"wake_timeout"`
|
||||
StopMethod string `yaml:"-" json:"stop_method"`
|
||||
StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only
|
||||
StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only
|
||||
Running bool `yaml:"-" json:"running"`
|
||||
}
|
||||
46
internal/error/base.go
Normal file
46
internal/error/base.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// baseError is an immutable wrapper around an error.
|
||||
type baseError struct {
|
||||
Err error `json:"err"`
|
||||
}
|
||||
|
||||
func (err *baseError) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
func (err *baseError) Is(other error) bool {
|
||||
if other, ok := other.(*baseError); ok {
|
||||
return errors.Is(err.Err, other.Err)
|
||||
}
|
||||
return errors.Is(err.Err, other)
|
||||
}
|
||||
|
||||
func (err baseError) Subject(subject string) Error {
|
||||
err.Err = PrependSubject(subject, err.Err)
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *baseError) Subjectf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
return err.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
return err.Subject(format)
|
||||
}
|
||||
|
||||
func (err baseError) With(extra error) Error {
|
||||
return &nestedError{&err, []error{extra}}
|
||||
}
|
||||
|
||||
func (err baseError) Withf(format string, args ...any) Error {
|
||||
return &nestedError{&err, []error{fmt.Errorf(format, args...)}}
|
||||
}
|
||||
|
||||
func (err *baseError) Error() string {
|
||||
return err.Err.Error()
|
||||
}
|
||||
@@ -2,69 +2,123 @@ package error
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Builder struct {
|
||||
*builder
|
||||
}
|
||||
|
||||
type builder struct {
|
||||
message string
|
||||
errors []NestedError
|
||||
about string
|
||||
errs []error
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewBuilder(format string, args ...any) Builder {
|
||||
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
|
||||
func NewBuilder(about string) *Builder {
|
||||
return &Builder{about: about}
|
||||
}
|
||||
|
||||
// adding nil / nil is no-op,
|
||||
// you may safely pass expressions returning error to it
|
||||
func (b Builder) Add(err NestedError) Builder {
|
||||
if err != nil {
|
||||
func (b *Builder) About() string {
|
||||
if !b.HasError() {
|
||||
return ""
|
||||
}
|
||||
return b.about
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func (b *Builder) HasError() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
func (b *Builder) error() Error {
|
||||
if !b.HasError() {
|
||||
return nil
|
||||
}
|
||||
return &nestedError{Err: New(b.about), Extras: b.errs}
|
||||
}
|
||||
|
||||
func (b *Builder) Error() Error {
|
||||
if len(b.errs) == 1 {
|
||||
return From(b.errs[0])
|
||||
}
|
||||
return b.error()
|
||||
}
|
||||
|
||||
func (b *Builder) String() string {
|
||||
err := b.error()
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// Add adds an error to the Builder.
|
||||
//
|
||||
// adding nil is no-op.
|
||||
func (b *Builder) Add(err error) *Builder {
|
||||
if err == nil {
|
||||
return b
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
switch err := err.(type) {
|
||||
case *baseError:
|
||||
b.errs = append(b.errs, err.Err)
|
||||
case *nestedError:
|
||||
if err.Err == nil {
|
||||
b.errs = append(b.errs, err.Extras...)
|
||||
} else {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
default:
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) Adds(err string) *Builder {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.errs = append(b.errs, newError(err))
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) Addf(format string, args ...any) *Builder {
|
||||
if len(args) > 0 {
|
||||
b.Lock()
|
||||
b.errors = append(b.errors, err)
|
||||
b.Unlock()
|
||||
defer b.Unlock()
|
||||
b.errs = append(b.errs, fmt.Errorf(format, args...))
|
||||
} else {
|
||||
b.Adds(format)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) AddFrom(other *Builder, flatten bool) *Builder {
|
||||
if other == nil || !other.HasError() {
|
||||
return b
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
if flatten {
|
||||
b.errs = append(b.errs, other.errs...)
|
||||
} else {
|
||||
b.errs = append(b.errs, other.error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b Builder) AddE(err error) Builder {
|
||||
return b.Add(From(err))
|
||||
}
|
||||
func (b *Builder) AddRange(errs ...error) *Builder {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
func (b Builder) Addf(format string, args ...any) Builder {
|
||||
return b.Add(errorf(format, args...))
|
||||
}
|
||||
|
||||
// Build builds a NestedError based on the errors collected in the Builder.
|
||||
//
|
||||
// If there are no errors in the Builder, it returns a Nil() NestedError.
|
||||
// Otherwise, it returns a NestedError with the message and the errors collected.
|
||||
//
|
||||
// Returns:
|
||||
// - NestedError: the built NestedError.
|
||||
func (b Builder) Build() NestedError {
|
||||
if len(b.errors) == 0 {
|
||||
return nil
|
||||
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
|
||||
return b.errors[0].Subject(b.message)
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
return Join(b.message, b.errors...)
|
||||
}
|
||||
|
||||
func (b Builder) To(ptr *NestedError) {
|
||||
if ptr == nil {
|
||||
return
|
||||
} else if *ptr == nil {
|
||||
*ptr = b.Build()
|
||||
} else {
|
||||
(*ptr).With(b.Build())
|
||||
}
|
||||
}
|
||||
|
||||
func (b Builder) HasError() bool {
|
||||
return len(b.errors) > 0
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1,53 +1,55 @@
|
||||
package error
|
||||
package error_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestBuilderEmpty(t *testing.T) {
|
||||
eb := NewBuilder("qwer")
|
||||
ExpectTrue(t, eb.Build() == nil)
|
||||
ExpectTrue(t, eb.Build().NoError())
|
||||
eb := NewBuilder("foo")
|
||||
ExpectTrue(t, errors.Is(eb.Error(), nil))
|
||||
ExpectFalse(t, eb.HasError())
|
||||
}
|
||||
|
||||
func TestBuilderAddNil(t *testing.T) {
|
||||
eb := NewBuilder("asdf")
|
||||
var err NestedError
|
||||
eb := NewBuilder("foo")
|
||||
var err Error
|
||||
for range 3 {
|
||||
eb.Add(nil)
|
||||
}
|
||||
for range 3 {
|
||||
eb.Add(err)
|
||||
}
|
||||
ExpectTrue(t, eb.Build() == nil)
|
||||
ExpectTrue(t, eb.Build().NoError())
|
||||
eb.AddRange(nil, nil, err)
|
||||
ExpectFalse(t, eb.HasError())
|
||||
ExpectTrue(t, eb.Error() == nil)
|
||||
}
|
||||
|
||||
func TestBuilderIs(t *testing.T) {
|
||||
eb := NewBuilder("foo")
|
||||
eb.Add(context.Canceled)
|
||||
eb.Add(io.ErrShortBuffer)
|
||||
ExpectTrue(t, eb.HasError())
|
||||
ExpectError(t, io.ErrShortBuffer, eb.Error())
|
||||
ExpectError(t, context.Canceled, eb.Error())
|
||||
}
|
||||
|
||||
func TestBuilderNested(t *testing.T) {
|
||||
eb := NewBuilder("error occurred")
|
||||
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
|
||||
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
|
||||
eb := NewBuilder("action failed")
|
||||
eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2"))
|
||||
eb.Add(New("Action 2").Withf("Inner: 3"))
|
||||
|
||||
got := eb.Build().String()
|
||||
expected1 :=
|
||||
(`error occurred:
|
||||
- Action 1 failed:
|
||||
- invalid Inner: 1
|
||||
- invalid Inner: 2
|
||||
- Action 2 failed:
|
||||
- invalid Inner: 3`)
|
||||
expected2 :=
|
||||
(`error occurred:
|
||||
- Action 1 failed:
|
||||
- invalid Inner: "1"
|
||||
- invalid Inner: "2"
|
||||
- Action 2 failed:
|
||||
- invalid Inner: "3"`)
|
||||
if got != expected1 && got != expected2 {
|
||||
t.Errorf("expected \n%s, got \n%s", expected1, got)
|
||||
}
|
||||
got := eb.String()
|
||||
expected := `action failed
|
||||
• Action 1
|
||||
• Inner: 1
|
||||
• Inner: 2
|
||||
• Action 2
|
||||
• Inner: 3`
|
||||
ExpectEqual(t, got, expected)
|
||||
}
|
||||
|
||||
@@ -1,296 +1,31 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
type Error interface {
|
||||
error
|
||||
|
||||
type (
|
||||
NestedError = *nestedError
|
||||
nestedError struct {
|
||||
subject string
|
||||
err error
|
||||
extras []nestedError
|
||||
}
|
||||
jsonNestedError struct {
|
||||
Subject string
|
||||
Err string
|
||||
Extras []jsonNestedError
|
||||
}
|
||||
)
|
||||
|
||||
func From(err error) NestedError {
|
||||
if IsNil(err) {
|
||||
return nil
|
||||
}
|
||||
return &nestedError{err: err}
|
||||
// Is is a wrapper for errors.Is when there is no sub-error.
|
||||
//
|
||||
// When there are sub-errors, they will also be checked.
|
||||
Is(other error) bool
|
||||
// With appends a sub-error to the error.
|
||||
With(extra error) Error
|
||||
// Withf is a wrapper for With(fmt.Errorf(format, args...)).
|
||||
Withf(format string, args ...any) Error
|
||||
// Subject prepends the given subject with a colon and space to the error message.
|
||||
//
|
||||
// If there is already a subject in the error message, the subject will be
|
||||
// prepended to the existing subject with " > ".
|
||||
//
|
||||
// Subject empty string is ignored.
|
||||
Subject(subject string) Error
|
||||
// Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)).
|
||||
Subjectf(format string, args ...any) Error
|
||||
}
|
||||
|
||||
func FromJSON(data []byte) (NestedError, bool) {
|
||||
var j jsonNestedError
|
||||
if err := json.Unmarshal(data, &j); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if j.Err == "" {
|
||||
return nil, false
|
||||
}
|
||||
extras := make([]nestedError, len(j.Extras))
|
||||
for i, e := range j.Extras {
|
||||
extra, ok := fromJSONObject(e)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
extras[i] = *extra
|
||||
}
|
||||
return &nestedError{
|
||||
subject: j.Subject,
|
||||
err: errors.New(j.Err),
|
||||
extras: extras,
|
||||
}, true
|
||||
}
|
||||
|
||||
// Check is a helper function that
|
||||
// convert (T, error) to (T, NestedError).
|
||||
func Check[T any](obj T, err error) (T, NestedError) {
|
||||
return obj, From(err)
|
||||
}
|
||||
|
||||
func Join(message string, err ...NestedError) NestedError {
|
||||
extras := make([]nestedError, len(err))
|
||||
nErr := 0
|
||||
for i, e := range err {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
extras[i] = *e
|
||||
nErr += 1
|
||||
}
|
||||
if nErr == 0 {
|
||||
return nil
|
||||
}
|
||||
return &nestedError{
|
||||
err: errors.New(message),
|
||||
extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func JoinE(message string, err ...error) NestedError {
|
||||
b := NewBuilder(message)
|
||||
for _, e := range err {
|
||||
b.AddE(e)
|
||||
}
|
||||
return b.Build()
|
||||
}
|
||||
|
||||
func IsNil(err error) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func IsNotNil(err error) bool {
|
||||
return err != nil
|
||||
}
|
||||
|
||||
func (ne NestedError) String() string {
|
||||
var buf strings.Builder
|
||||
ne.writeToSB(&buf, 0, "")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (ne NestedError) Is(err error) bool {
|
||||
if ne == nil {
|
||||
return err == nil
|
||||
}
|
||||
// return errors.Is(ne.err, err)
|
||||
if errors.Is(ne.err, err) {
|
||||
return true
|
||||
}
|
||||
for _, e := range ne.extras {
|
||||
if e.Is(err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ne NestedError) IsNot(err error) bool {
|
||||
return !ne.Is(err)
|
||||
}
|
||||
|
||||
func (ne NestedError) Error() error {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
return ne.buildError(0, "")
|
||||
}
|
||||
|
||||
func (ne NestedError) With(s any) NestedError {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
var msg string
|
||||
switch ss := s.(type) {
|
||||
case nil:
|
||||
return ne
|
||||
case NestedError:
|
||||
return ne.withError(ss)
|
||||
case error:
|
||||
return ne.withError(From(ss))
|
||||
case string:
|
||||
msg = ss
|
||||
case fmt.Stringer:
|
||||
return ne.appendMsg(ss.String())
|
||||
default:
|
||||
return ne.appendMsg(fmt.Sprint(s))
|
||||
}
|
||||
return ne.withError(From(errors.New(msg)))
|
||||
}
|
||||
|
||||
func (ne NestedError) Extraf(format string, args ...any) NestedError {
|
||||
return ne.With(errorf(format, args...))
|
||||
}
|
||||
|
||||
func (ne NestedError) Subject(s any) NestedError {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
var subject string
|
||||
switch ss := s.(type) {
|
||||
case string:
|
||||
subject = ss
|
||||
case fmt.Stringer:
|
||||
subject = ss.String()
|
||||
default:
|
||||
subject = fmt.Sprint(s)
|
||||
}
|
||||
if ne.subject == "" {
|
||||
ne.subject = subject
|
||||
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
|
||||
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
|
||||
} else {
|
||||
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
|
||||
}
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
if strings.Contains(format, "%q") {
|
||||
panic("Subjectf format should not contain %q")
|
||||
}
|
||||
if strings.Contains(format, "%w") {
|
||||
panic("Subjectf format should not contain %w")
|
||||
}
|
||||
return ne.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (ne NestedError) JSONObject() jsonNestedError {
|
||||
extras := make([]jsonNestedError, len(ne.extras))
|
||||
for i, e := range ne.extras {
|
||||
extras[i] = e.JSONObject()
|
||||
}
|
||||
return jsonNestedError{
|
||||
Subject: ne.subject,
|
||||
Err: ne.err.Error(),
|
||||
Extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func (ne NestedError) JSON() []byte {
|
||||
b, _ := json.MarshalIndent(ne.JSONObject(), "", " ")
|
||||
return b
|
||||
}
|
||||
|
||||
func (ne NestedError) NoError() bool {
|
||||
return ne == nil
|
||||
}
|
||||
|
||||
func (ne NestedError) HasError() bool {
|
||||
return ne != nil
|
||||
}
|
||||
|
||||
func errorf(format string, args ...any) NestedError {
|
||||
return From(fmt.Errorf(format, args...))
|
||||
}
|
||||
|
||||
func fromJSONObject(obj jsonNestedError) (NestedError, bool) {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return FromJSON(data)
|
||||
}
|
||||
|
||||
func (ne NestedError) withError(err NestedError) NestedError {
|
||||
if ne != nil && err != nil {
|
||||
ne.extras = append(ne.extras, *err)
|
||||
}
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) appendMsg(msg string) NestedError {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
ne.err = fmt.Errorf("%w %s", ne.err, msg)
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
|
||||
for i := 0; i < level; i++ {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
sb.WriteString(prefix)
|
||||
|
||||
if ne.NoError() {
|
||||
sb.WriteString("nil")
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(ne.err.Error())
|
||||
if ne.subject != "" {
|
||||
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
|
||||
}
|
||||
if len(ne.extras) > 0 {
|
||||
sb.WriteRune(':')
|
||||
for _, extra := range ne.extras {
|
||||
sb.WriteRune('\n')
|
||||
extra.writeToSB(sb, level+1, "- ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ne NestedError) buildError(level int, prefix string) error {
|
||||
var res error
|
||||
var sb strings.Builder
|
||||
|
||||
for i := 0; i < level; i++ {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
sb.WriteString(prefix)
|
||||
|
||||
if ne.NoError() {
|
||||
sb.WriteString("nil")
|
||||
return errors.New(sb.String())
|
||||
}
|
||||
|
||||
res = fmt.Errorf("%s%w", sb.String(), ne.err)
|
||||
sb.Reset()
|
||||
|
||||
if ne.subject != "" {
|
||||
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
|
||||
}
|
||||
if len(ne.extras) > 0 {
|
||||
sb.WriteRune(':')
|
||||
res = fmt.Errorf("%w%s", res, sb.String())
|
||||
for _, extra := range ne.extras {
|
||||
res = errors.Join(res, extra.buildError(level+1, "- "))
|
||||
}
|
||||
} else {
|
||||
res = fmt.Errorf("%w%s", res, sb.String())
|
||||
}
|
||||
return res
|
||||
// this makes JSON marshaling work,
|
||||
// as the builtin one doesn't.
|
||||
type errStr string
|
||||
|
||||
func (err errStr) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
@@ -1,109 +1,157 @@
|
||||
package error_test
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestBaseString(t *testing.T) {
|
||||
ExpectEqual(t, New("error").Error(), "error")
|
||||
}
|
||||
|
||||
func TestBaseWithSubject(t *testing.T) {
|
||||
err := New("error")
|
||||
withSubject := err.Subject("foo")
|
||||
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
|
||||
|
||||
ExpectError(t, err, withSubject)
|
||||
ExpectStrEqual(t, withSubject.Error(), "foo: error")
|
||||
ExpectTrue(t, withSubject.Is(err))
|
||||
|
||||
ExpectError(t, err, withSubjectf)
|
||||
ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error")
|
||||
ExpectTrue(t, withSubjectf.Is(err))
|
||||
}
|
||||
|
||||
func TestBaseWithExtra(t *testing.T) {
|
||||
err := New("error")
|
||||
extra := New("bar").Subject("baz")
|
||||
withExtra := err.With(extra)
|
||||
|
||||
ExpectTrue(t, withExtra.Is(extra))
|
||||
ExpectTrue(t, withExtra.Is(err))
|
||||
|
||||
ExpectTrue(t, errors.Is(withExtra, extra))
|
||||
ExpectTrue(t, errors.Is(withExtra, err))
|
||||
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
|
||||
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
|
||||
}
|
||||
|
||||
func TestBaseUnwrap(t *testing.T) {
|
||||
err := errors.New("err")
|
||||
wrapped := From(err)
|
||||
|
||||
ExpectError(t, err, errors.Unwrap(wrapped))
|
||||
}
|
||||
|
||||
func TestNestedUnwrap(t *testing.T) {
|
||||
err := errors.New("err")
|
||||
err2 := New("err2")
|
||||
wrapped := From(err).Subject("foo").With(err2.Subject("bar"))
|
||||
|
||||
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
|
||||
ExpectTrue(t, ok)
|
||||
|
||||
ExpectError(t, err, wrapped)
|
||||
ExpectError(t, err2, wrapped)
|
||||
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
|
||||
}
|
||||
|
||||
func TestErrorIs(t *testing.T) {
|
||||
ExpectTrue(t, Failure("foo").Is(ErrFailure))
|
||||
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
|
||||
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
|
||||
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
|
||||
from := errors.New("error")
|
||||
err := From(from)
|
||||
ExpectError(t, from, err)
|
||||
|
||||
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
|
||||
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
|
||||
ExpectTrue(t, err.Is(from))
|
||||
ExpectFalse(t, err.Is(New("error")))
|
||||
|
||||
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
|
||||
|
||||
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
|
||||
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
|
||||
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
|
||||
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
|
||||
ExpectTrue(t, errors.Is(err.Subject("foo"), from))
|
||||
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
|
||||
ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
|
||||
}
|
||||
|
||||
func TestErrorNestedIs(t *testing.T) {
|
||||
var err NestedError
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
func TestErrorImmutability(t *testing.T) {
|
||||
err := New("err")
|
||||
err2 := New("err2")
|
||||
|
||||
err = Failure("some reason")
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectFalse(t, err.Is(ErrDuplicated))
|
||||
for range 3 {
|
||||
// t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
|
||||
err.Subject("foo")
|
||||
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
|
||||
|
||||
err.With(Duplicated("something", ""))
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectTrue(t, err.Is(ErrDuplicated))
|
||||
ExpectFalse(t, err.Is(ErrInvalid))
|
||||
}
|
||||
err.With(err2)
|
||||
ExpectFalse(t, strings.Contains(err.Error(), "extra"))
|
||||
ExpectFalse(t, err.Is(err2))
|
||||
|
||||
func TestIsNil(t *testing.T) {
|
||||
var err NestedError
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
ExpectFalse(t, err.HasError())
|
||||
ExpectTrue(t, err == nil)
|
||||
ExpectTrue(t, err.NoError())
|
||||
|
||||
eb := NewBuilder("")
|
||||
returnNil := func() error {
|
||||
return eb.Build().Error()
|
||||
err = err.Subject("bar").Withf("baz")
|
||||
ExpectTrue(t, err != nil)
|
||||
}
|
||||
ExpectTrue(t, IsNil(returnNil()))
|
||||
ExpectTrue(t, returnNil() == nil)
|
||||
|
||||
ExpectTrue(t, (err.
|
||||
Subject("any").
|
||||
With("something").
|
||||
Extraf("foo %s", "bar")) == nil)
|
||||
}
|
||||
|
||||
func TestErrorSimple(t *testing.T) {
|
||||
ne := Failure("foo bar")
|
||||
ExpectEqual(t, ne.String(), "foo bar failed")
|
||||
ne = ne.Subject("baz")
|
||||
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
|
||||
}
|
||||
|
||||
func TestErrorWith(t *testing.T) {
|
||||
ne := Failure("foo").With("bar").With("baz")
|
||||
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
|
||||
err1 := New("err1")
|
||||
err2 := New("err2")
|
||||
|
||||
err3 := err1.With(err2)
|
||||
|
||||
ExpectTrue(t, err3.Is(err1))
|
||||
ExpectTrue(t, err3.Is(err2))
|
||||
|
||||
err2.Subject("foo")
|
||||
|
||||
ExpectTrue(t, err3.Is(err1))
|
||||
ExpectTrue(t, err3.Is(err2))
|
||||
|
||||
// check if err3 is affected by err2.Subject
|
||||
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
|
||||
}
|
||||
|
||||
func TestErrorNested(t *testing.T) {
|
||||
inner := Failure("inner").
|
||||
With("1").
|
||||
With("1")
|
||||
inner2 := Failure("inner2").
|
||||
func TestErrorStringSimple(t *testing.T) {
|
||||
errFailure := New("generic failure")
|
||||
ne := errFailure.Subject("foo bar")
|
||||
ExpectStrEqual(t, ne.Error(), "foo bar: generic failure")
|
||||
ne = ne.Subject("baz")
|
||||
ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure")
|
||||
}
|
||||
|
||||
func TestErrorStringNested(t *testing.T) {
|
||||
errFailure := New("generic failure")
|
||||
inner := errFailure.Subject("inner").
|
||||
Withf("1").
|
||||
Withf("1")
|
||||
inner2 := errFailure.Subject("inner2").
|
||||
Subject("action 2").
|
||||
With("2").
|
||||
With("2")
|
||||
inner3 := Failure("inner3").
|
||||
Withf("2").
|
||||
Withf("2")
|
||||
inner3 := errFailure.Subject("inner3").
|
||||
Subject("action 3").
|
||||
With("3").
|
||||
With("3")
|
||||
ne := Failure("foo").
|
||||
With("bar").
|
||||
With("baz").
|
||||
Withf("3").
|
||||
Withf("3")
|
||||
ne := errFailure.
|
||||
Subject("foo").
|
||||
Withf("bar").
|
||||
Withf("baz").
|
||||
With(inner).
|
||||
With(inner.With(inner2.With(inner3)))
|
||||
want :=
|
||||
`foo failed:
|
||||
- bar
|
||||
- baz
|
||||
- inner failed:
|
||||
- 1
|
||||
- 1
|
||||
- inner failed:
|
||||
- 1
|
||||
- 1
|
||||
- inner2 failed for "action 2":
|
||||
- 2
|
||||
- 2
|
||||
- inner3 failed for "action 3":
|
||||
- 3
|
||||
- 3`
|
||||
ExpectEqual(t, ne.String(), want)
|
||||
ExpectEqual(t, ne.Error().Error(), want)
|
||||
want := `foo: generic failure
|
||||
• bar
|
||||
• baz
|
||||
• inner: generic failure
|
||||
• 1
|
||||
• 1
|
||||
• inner: generic failure
|
||||
• 1
|
||||
• 1
|
||||
• action 2 > inner2: generic failure
|
||||
• 2
|
||||
• 2
|
||||
• action 3 > inner3: generic failure
|
||||
• 3
|
||||
• 3`
|
||||
ExpectStrEqual(t, ne.Error(), want)
|
||||
}
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrDuplicated = stderrors.New("duplicated")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrTypeError = stderrors.New("type error")
|
||||
ErrTypeMismatch = stderrors.New("type mismatch")
|
||||
)
|
||||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
||||
func Failure(what string) NestedError {
|
||||
return errorf("%s %w", what, ErrFailure)
|
||||
}
|
||||
|
||||
func FailedWhy(what string, why string) NestedError {
|
||||
return Failure(what).With(why)
|
||||
}
|
||||
|
||||
func FailWith(what string, err any) NestedError {
|
||||
return Failure(what).With(err)
|
||||
}
|
||||
|
||||
func Invalid(subject, what any) NestedError {
|
||||
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
|
||||
}
|
||||
|
||||
func Unsupported(subject, what any) NestedError {
|
||||
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
|
||||
}
|
||||
|
||||
func Unexpected(subject, what any) NestedError {
|
||||
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
|
||||
}
|
||||
|
||||
func UnexpectedError(err error) NestedError {
|
||||
return errorf("%w error: %w", ErrUnexpected, err)
|
||||
}
|
||||
|
||||
func NotExist(subject, what any) NestedError {
|
||||
return errorf("%v %w: %v", subject, ErrNotExists, what)
|
||||
}
|
||||
|
||||
func Missing(subject any) NestedError {
|
||||
return errorf("%w %v", ErrMissing, subject)
|
||||
}
|
||||
|
||||
func Duplicated(subject, what any) NestedError {
|
||||
return errorf("%w %v: %v", ErrDuplicated, subject, what)
|
||||
}
|
||||
|
||||
func OutOfRange(subject any, value any) NestedError {
|
||||
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
|
||||
}
|
||||
|
||||
func TypeError(subject any, from, to reflect.Type) NestedError {
|
||||
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
|
||||
}
|
||||
|
||||
func TypeError2(subject any, from, to reflect.Value) NestedError {
|
||||
return TypeError(subject, from.Type(), to.Type())
|
||||
}
|
||||
|
||||
func TypeMismatch[Expect any](value any) NestedError {
|
||||
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
|
||||
}
|
||||
43
internal/error/log.go
Normal file
43
internal/error/log.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func getLogger(logger ...*zerolog.Logger) *zerolog.Logger {
|
||||
if len(logger) > 0 {
|
||||
return logger[0]
|
||||
}
|
||||
return logging.GetLogger()
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogFatal(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Fatal().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogError(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Error().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogWarn(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Warn().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogPanic(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Panic().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogInfo(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Info().Msg(err.Error())
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func LogDebug(msg string, err error, logger ...*zerolog.Logger) {
|
||||
getLogger(logger...).Debug().Msg(err.Error())
|
||||
}
|
||||
120
internal/error/nested_error.go
Normal file
120
internal/error/nested_error.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type nestedError struct {
|
||||
Err error `json:"err"`
|
||||
Extras []error `json:"extras"`
|
||||
}
|
||||
|
||||
func (err nestedError) Subject(subject string) Error {
|
||||
if err.Err == nil {
|
||||
err.Err = newError(subject)
|
||||
} else {
|
||||
err.Err = PrependSubject(subject, err.Err)
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *nestedError) Subjectf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
return err.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
return err.Subject(format)
|
||||
}
|
||||
|
||||
func (err nestedError) With(extra error) Error {
|
||||
if extra != nil {
|
||||
err.Extras = append(err.Extras, extra)
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err nestedError) Withf(format string, args ...any) Error {
|
||||
if len(args) > 0 {
|
||||
err.Extras = append(err.Extras, fmt.Errorf(format, args...))
|
||||
} else {
|
||||
err.Extras = append(err.Extras, newError(format))
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *nestedError) Unwrap() []error {
|
||||
if err.Err == nil {
|
||||
if len(err.Extras) == 0 {
|
||||
return nil
|
||||
}
|
||||
return err.Extras
|
||||
}
|
||||
return append([]error{err.Err}, err.Extras...)
|
||||
}
|
||||
|
||||
func (err *nestedError) Is(other error) bool {
|
||||
if errors.Is(err.Err, other) {
|
||||
return true
|
||||
}
|
||||
for _, e := range err.Extras {
|
||||
if errors.Is(e, other) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (err *nestedError) Error() string {
|
||||
return buildError(err, 0)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func makeLine(err string, level int) string {
|
||||
const bulletPrefix = "• "
|
||||
const spaces = " "
|
||||
|
||||
if level == 0 {
|
||||
return err
|
||||
}
|
||||
return spaces[:2*level] + bulletPrefix + err
|
||||
}
|
||||
|
||||
func makeLines(errs []error, level int) []string {
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
lines := make([]string, 0, len(errs))
|
||||
for _, err := range errs {
|
||||
switch err := err.(type) {
|
||||
case *nestedError:
|
||||
if err.Err != nil {
|
||||
lines = append(lines, makeLine(err.Err.Error(), level))
|
||||
}
|
||||
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
|
||||
lines = append(lines, extras...)
|
||||
}
|
||||
default:
|
||||
lines = append(lines, makeLine(err.Error(), level))
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func buildError(err error, level int) string {
|
||||
switch err := err.(type) {
|
||||
case nil:
|
||||
return makeLine("<nil>", level)
|
||||
case *nestedError:
|
||||
lines := make([]string, 0, 1+len(err.Extras))
|
||||
if err.Err != nil {
|
||||
lines = append(lines, makeLine(err.Err.Error(), level))
|
||||
}
|
||||
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
|
||||
lines = append(lines, extras...)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
default:
|
||||
return makeLine(err.Error(), level)
|
||||
}
|
||||
}
|
||||
52
internal/error/subject.go
Normal file
52
internal/error/subject.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
||||
)
|
||||
|
||||
type withSubject struct {
|
||||
Subject string `json:"subject"`
|
||||
Err error `json:"err"`
|
||||
}
|
||||
|
||||
const subjectSep = " > "
|
||||
|
||||
func highlight(subject string) string {
|
||||
return ansi.HighlightRed + subject + ansi.Reset
|
||||
}
|
||||
|
||||
func PrependSubject(subject string, err error) error {
|
||||
switch err := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *withSubject:
|
||||
return err.Prepend(subject)
|
||||
case Error:
|
||||
return err.Subject(subject)
|
||||
default:
|
||||
return &withSubject{subject, err}
|
||||
}
|
||||
}
|
||||
|
||||
func (err withSubject) Prepend(subject string) *withSubject {
|
||||
if subject != "" {
|
||||
err.Subject = subject + subjectSep + err.Subject
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (err *withSubject) Is(other error) bool {
|
||||
return err.Err == other
|
||||
}
|
||||
|
||||
func (err *withSubject) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
func (err *withSubject) Error() string {
|
||||
subjects := strings.Split(err.Subject, subjectSep)
|
||||
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1])
|
||||
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error()
|
||||
}
|
||||
68
internal/error/utils.go
Normal file
68
internal/error/utils.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func newError(message string) error {
|
||||
return errStr(message)
|
||||
}
|
||||
|
||||
func New(message string) Error {
|
||||
if message == "" {
|
||||
return nil
|
||||
}
|
||||
return &baseError{newError(message)}
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...any) Error {
|
||||
return &baseError{fmt.Errorf(format, args...)}
|
||||
}
|
||||
|
||||
func From(err error) Error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err, ok := err.(Error); ok {
|
||||
return err
|
||||
}
|
||||
return &baseError{err}
|
||||
}
|
||||
|
||||
func Must[T any](v T, err error) T {
|
||||
if err != nil {
|
||||
LogPanic("must failed", err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func Join(errors ...error) Error {
|
||||
n := 0
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
n++
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := make([]error, 0, n)
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return &nestedError{Extras: errs}
|
||||
}
|
||||
|
||||
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
|
||||
result, err := fn(arg)
|
||||
eb.Add(err)
|
||||
return result
|
||||
}
|
||||
|
||||
func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T {
|
||||
result, err := fn(arg1, arg2)
|
||||
eb.Add(err)
|
||||
return result
|
||||
}
|
||||
@@ -1,24 +1,24 @@
|
||||
package homepage
|
||||
|
||||
type (
|
||||
HomePageConfig map[string]HomePageCategory
|
||||
HomePageCategory []*HomePageItem
|
||||
Config map[string]Category
|
||||
Category []*Item
|
||||
|
||||
HomePageItem struct {
|
||||
Show bool `yaml:"show" json:"show"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Icon string `yaml:"icon" json:"icon"`
|
||||
URL string `yaml:"url" json:"url"` // alias + domain
|
||||
Category string `yaml:"category" json:"category"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
WidgetConfig map[string]any `yaml:",flow" json:"widget_config"`
|
||||
Item struct {
|
||||
Show bool `json:"show" yaml:"show"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Icon string `json:"icon" yaml:"icon"`
|
||||
URL string `json:"url" yaml:"url"` // alias + domain
|
||||
Category string `json:"category" yaml:"category"`
|
||||
Description string `json:"description" yaml:"description"`
|
||||
WidgetConfig map[string]any `json:"widget_config" yaml:",flow"`
|
||||
|
||||
SourceType string `yaml:"-" json:"source_type"`
|
||||
AltURL string `yaml:"-" json:"alt_url"` // original proxy target
|
||||
SourceType string `json:"source_type" yaml:"-"`
|
||||
AltURL string `json:"alt_url" yaml:"-"` // original proxy target
|
||||
}
|
||||
)
|
||||
|
||||
func (item *HomePageItem) IsEmpty() bool {
|
||||
func (item *Item) IsEmpty() bool {
|
||||
return item == nil || (item.Name == "" &&
|
||||
item.Icon == "" &&
|
||||
item.URL == "" &&
|
||||
@@ -27,17 +27,17 @@ func (item *HomePageItem) IsEmpty() bool {
|
||||
len(item.WidgetConfig) == 0)
|
||||
}
|
||||
|
||||
func NewHomePageConfig() HomePageConfig {
|
||||
return HomePageConfig(make(map[string]HomePageCategory))
|
||||
func NewHomePageConfig() Config {
|
||||
return Config(make(map[string]Category))
|
||||
}
|
||||
|
||||
func (c *HomePageConfig) Clear() {
|
||||
*c = make(HomePageConfig)
|
||||
func (c *Config) Clear() {
|
||||
*c = make(Config)
|
||||
}
|
||||
|
||||
func (c HomePageConfig) Add(item *HomePageItem) {
|
||||
func (c Config) Add(item *Item) {
|
||||
if c[item.Category] == nil {
|
||||
c[item.Category] = make(HomePageCategory, 0)
|
||||
c[item.Category] = make(Category, 0)
|
||||
}
|
||||
c[item.Category] = append(c[item.Category], item)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"log"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
@@ -21,8 +20,10 @@ type GitHubContents struct { //! keep this, may reuse in future
|
||||
Size int `json:"size"`
|
||||
}
|
||||
|
||||
const iconsCachePath = "/tmp/icons_cache.json"
|
||||
const updateInterval = 1 * time.Hour
|
||||
const (
|
||||
iconsCachePath = "/tmp/icons_cache.json"
|
||||
updateInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
func ListAvailableIcons() ([]string, error) {
|
||||
owner := "walkxcode"
|
||||
@@ -30,13 +31,14 @@ func ListAvailableIcons() ([]string, error) {
|
||||
ref := "main"
|
||||
|
||||
var lastUpdate time.Time
|
||||
var icons = make([]string, 0)
|
||||
|
||||
icons := make([]string, 0)
|
||||
info, err := os.Stat(iconsCachePath)
|
||||
if err == nil {
|
||||
lastUpdate = info.ModTime().Local()
|
||||
}
|
||||
if time.Since(lastUpdate) < updateInterval {
|
||||
err := utils.LoadJson(iconsCachePath, &icons)
|
||||
err := utils.LoadJSON(iconsCachePath, &icons)
|
||||
if err == nil {
|
||||
return icons, nil
|
||||
}
|
||||
@@ -51,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
|
||||
icons = append(icons, content.Path)
|
||||
}
|
||||
}
|
||||
err = utils.SaveJson(iconsCachePath, &icons, 0o644).Error()
|
||||
err = utils.SaveJSON(iconsCachePath, &icons, 0o644)
|
||||
if err != nil {
|
||||
log.Print("error saving cache", err)
|
||||
}
|
||||
@@ -59,7 +61,7 @@ func ListAvailableIcons() ([]string, error) {
|
||||
}
|
||||
|
||||
func getRepoContents(client *http.Client, owner string, repo string, ref string, path string) ([]GitHubContents, error) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
70
internal/logging/logging.go
Normal file
70
internal/logging/logging.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var logger zerolog.Logger
|
||||
|
||||
func init() {
|
||||
var timeFmt string
|
||||
var level zerolog.Level
|
||||
var exclude []string
|
||||
|
||||
switch {
|
||||
case common.IsTrace:
|
||||
timeFmt = "04:05"
|
||||
level = zerolog.TraceLevel
|
||||
case common.IsDebug:
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.DebugLevel
|
||||
default:
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.InfoLevel
|
||||
exclude = []string{"module"}
|
||||
}
|
||||
|
||||
prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces
|
||||
prefix := strings.Repeat(" ", prefixLength)
|
||||
|
||||
logger = zerolog.New(
|
||||
zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: timeFmt,
|
||||
FieldsExclude: exclude,
|
||||
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
|
||||
msg := msgI.(string)
|
||||
lines := strings.Split(msg, "\n")
|
||||
if len(lines) == 1 {
|
||||
return msg
|
||||
}
|
||||
for i := 1; i < len(lines); i++ {
|
||||
lines[i] = prefix + lines[i]
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
},
|
||||
},
|
||||
).Level(level).With().Timestamp().Logger()
|
||||
}
|
||||
|
||||
func DiscardLogger() { logger = zerolog.Nop() }
|
||||
|
||||
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
|
||||
|
||||
func GetLogger() *zerolog.Logger { return &logger }
|
||||
func With() zerolog.Context { return logger.With() }
|
||||
|
||||
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
|
||||
|
||||
func Info() *zerolog.Event { return logger.Info() }
|
||||
func Warn() *zerolog.Event { return logger.Warn() }
|
||||
func Error() *zerolog.Event { return logger.Error() }
|
||||
func Err(err error) *zerolog.Event { return logger.Err(err) }
|
||||
func Debug() *zerolog.Event { return logger.Debug() }
|
||||
func Fatal() *zerolog.Event { return logger.Fatal() }
|
||||
func Panic() *zerolog.Event { return logger.Panic() }
|
||||
func Trace() *zerolog.Event { return logger.Trace() }
|
||||
@@ -16,13 +16,12 @@ var (
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: defaultDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
DefaultTransportNoTLS = func() *http.Transport {
|
||||
var clone = DefaultTransport.Clone()
|
||||
clone := DefaultTransport.Clone()
|
||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
return clone
|
||||
}()
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ContentType string
|
||||
type (
|
||||
ContentType string
|
||||
AcceptContentType []ContentType
|
||||
)
|
||||
|
||||
func GetContentType(h http.Header) ContentType {
|
||||
ct := h.Get("Content-Type")
|
||||
@@ -19,6 +22,18 @@ func GetContentType(h http.Header) ContentType {
|
||||
return ContentType(ct)
|
||||
}
|
||||
|
||||
func GetAccept(h http.Header) AcceptContentType {
|
||||
var accepts []ContentType
|
||||
for _, v := range h["Accept"] {
|
||||
ct, _, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
accepts = append(accepts, ContentType(ct))
|
||||
}
|
||||
return accepts
|
||||
}
|
||||
|
||||
func (ct ContentType) IsHTML() bool {
|
||||
return ct == "text/html" || ct == "application/xhtml+xml"
|
||||
}
|
||||
@@ -30,3 +45,34 @@ func (ct ContentType) IsJSON() bool {
|
||||
func (ct ContentType) IsPlainText() bool {
|
||||
return ct == "text/plain"
|
||||
}
|
||||
|
||||
func (act AcceptContentType) IsEmpty() bool {
|
||||
return len(act) == 0
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptHTML() bool {
|
||||
for _, v := range act {
|
||||
if v.IsHTML() || v == "text/*" || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptJSON() bool {
|
||||
for _, v := range act {
|
||||
if v.IsJSON() || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptPlainText() bool {
|
||||
for _, v := range act {
|
||||
if v.IsPlainText() || v == "text/*" || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
41
internal/net/http/content_type_test.go
Normal file
41
internal/net/http/content_type_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestContentTypes(t *testing.T) {
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
|
||||
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
|
||||
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
|
||||
}
|
||||
|
||||
func TestAcceptContentTypes(t *testing.T) {
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
|
||||
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
|
||||
}
|
||||
15
internal/net/http/dummy_response_writer.go
Normal file
15
internal/net/http/dummy_response_writer.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type DummyResponseWriter struct{}
|
||||
|
||||
func (w DummyResponseWriter) Header() http.Header {
|
||||
return make(http.Header)
|
||||
}
|
||||
|
||||
func (w DummyResponseWriter) Write([]byte) (_ int, _ error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w DummyResponseWriter) WriteHeader(int) {}
|
||||
91
internal/net/http/loadbalancer/ip_hash.go
Normal file
91
internal/net/http/loadbalancer/ip_hash.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
*LoadBalancer
|
||||
|
||||
realIP *middleware.Middleware
|
||||
pool servers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl {
|
||||
impl := &ipHash{LoadBalancer: lb}
|
||||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err E.Error
|
||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
||||
if err != nil {
|
||||
E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
|
||||
}
|
||||
return impl
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnAddServer(srv *Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
impl.pool[i] = srv
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
impl.pool = append(impl.pool, srv)
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnRemoveServer(srv *Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
impl.pool[i] = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
if impl.realIP != nil {
|
||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
||||
} else {
|
||||
impl.serveHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
|
||||
srv := impl.pool[idx]
|
||||
if srv == nil || srv.Status().Bad() {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
}
|
||||
srv.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func hashIP(ip string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(ip))
|
||||
return h.Sum32()
|
||||
}
|
||||
53
internal/net/http/loadbalancer/least_conn.go
Normal file
53
internal/net/http/loadbalancer/least_conn.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type leastConn struct {
|
||||
*LoadBalancer
|
||||
nConn F.Map[*Server, *atomic.Int64]
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newLeastConn() impl {
|
||||
return &leastConn{
|
||||
LoadBalancer: lb,
|
||||
nConn: F.NewMapOf[*Server, *atomic.Int64](),
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnAddServer(srv *Server) {
|
||||
impl.nConn.Store(srv, new(atomic.Int64))
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnRemoveServer(srv *Server) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
impl.Error().Msgf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
minConn = nConn
|
||||
srv = srvs[i]
|
||||
}
|
||||
}
|
||||
|
||||
minConn.Add(1)
|
||||
srv.ServeHTTP(rw, r)
|
||||
minConn.Add(-1)
|
||||
}
|
||||
297
internal/net/http/loadbalancer/loadbalancer.go
Normal file
297
internal/net/http/loadbalancer/loadbalancer.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// TODO: stats of each server.
|
||||
// TODO: support weighted mode.
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv *Server)
|
||||
OnRemoveServer(srv *Server)
|
||||
}
|
||||
Config struct {
|
||||
Link string `json:"link" yaml:"link"`
|
||||
Mode Mode `json:"mode" yaml:"mode"`
|
||||
Weight weightType `json:"weight" yaml:"weight"`
|
||||
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
|
||||
}
|
||||
LoadBalancer struct {
|
||||
zerolog.Logger
|
||||
|
||||
impl
|
||||
*Config
|
||||
|
||||
task task.Task
|
||||
|
||||
pool Pool
|
||||
poolMu sync.Mutex
|
||||
|
||||
sumWeight weightType
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
weightType uint16
|
||||
)
|
||||
|
||||
const maxWeight weightType = 100
|
||||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
Logger: logger.With().Str("name", cfg.Link).Logger(),
|
||||
Config: new(Config),
|
||||
pool: newPool(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = routeSubtask
|
||||
lb.task.OnFinished("loadbalancer cleanup", func() {
|
||||
if lb.impl != nil {
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
lb.impl.OnRemoveServer(v)
|
||||
})
|
||||
}
|
||||
lb.pool.Clear()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (lb *LoadBalancer) Finish(reason any) {
|
||||
lb.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) updateImpl() {
|
||||
switch lb.Mode {
|
||||
case Unset, RoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
}
|
||||
lb.pool.RangeAll(func(_ string, srv *Server) {
|
||||
lb.impl.OnAddServer(srv)
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
if cfg != nil {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.Link = cfg.Link
|
||||
|
||||
if lb.Mode == Unset && cfg.Mode != Unset {
|
||||
lb.Mode = cfg.Mode
|
||||
if !lb.Mode.ValidateUpdate() {
|
||||
lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
}
|
||||
lb.updateImpl()
|
||||
}
|
||||
|
||||
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
|
||||
lb.Options = cfg.Options
|
||||
}
|
||||
}
|
||||
|
||||
if lb.impl == nil {
|
||||
lb.updateImpl()
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) AddServer(srv *Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if lb.pool.Has(srv.Name) {
|
||||
old, _ := lb.pool.Load(srv.Name)
|
||||
lb.sumWeight -= old.Weight
|
||||
lb.impl.OnRemoveServer(old)
|
||||
}
|
||||
lb.pool.Store(srv.Name, srv)
|
||||
lb.sumWeight += srv.Weight
|
||||
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
|
||||
lb.Debug().
|
||||
Str("action", "add").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers available", lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if !lb.pool.Has(srv.Name) {
|
||||
return
|
||||
}
|
||||
|
||||
lb.pool.Delete(srv.Name)
|
||||
|
||||
lb.sumWeight -= srv.Weight
|
||||
lb.rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
lb.Debug().
|
||||
Str("action", "remove").
|
||||
Str("server", srv.Name).
|
||||
Msgf("%d servers left", lb.pool.Size())
|
||||
|
||||
if lb.pool.Size() == 0 {
|
||||
lb.task.Finish("no server left")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) rebalance() {
|
||||
if lb.sumWeight == maxWeight {
|
||||
return
|
||||
}
|
||||
if lb.pool.Size() == 0 {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / weightType(lb.pool.Size())
|
||||
remainder := maxWeight % weightType(lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightEach
|
||||
lb.sumWeight += weightEach
|
||||
if remainder > 0 {
|
||||
s.Weight++
|
||||
remainder--
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// scale evenly
|
||||
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||
lb.sumWeight = 0
|
||||
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||
lb.sumWeight += s.Weight
|
||||
})
|
||||
|
||||
delta := maxWeight - lb.sumWeight
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
lb.pool.Range(func(_ string, s *Server) bool {
|
||||
if delta == 0 {
|
||||
return false
|
||||
}
|
||||
if delta > 0 {
|
||||
s.Weight++
|
||||
lb.sumWeight++
|
||||
delta--
|
||||
} else {
|
||||
s.Weight--
|
||||
lb.sumWeight--
|
||||
delta++
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
srvs := lb.availServers()
|
||||
if len(srvs) == 0 {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
|
||||
defer cancel()
|
||||
// send dummy request to wake all servers
|
||||
var dummyRW gphttp.DummyResponseWriter
|
||||
for _, srv := range srvs {
|
||||
// wake only if server implements Waker
|
||||
_, ok := srv.handler.(idlewatcher.Waker)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wakeReq := r.Clone(ctx)
|
||||
srv.ServeHTTP(dummyRW, wakeReq)
|
||||
}
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Uptime() time.Duration {
|
||||
return time.Since(lb.startTime)
|
||||
}
|
||||
|
||||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
extra := make(map[string]any)
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
extra[v.Name] = v.healthMon
|
||||
})
|
||||
|
||||
return (&health.JSONRepresentation{
|
||||
Name: lb.Name(),
|
||||
Status: lb.Status(),
|
||||
Started: lb.startTime,
|
||||
Uptime: lb.Uptime(),
|
||||
Extra: map[string]any{
|
||||
"config": lb.Config,
|
||||
"pool": extra,
|
||||
},
|
||||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Name() string {
|
||||
return lb.Link
|
||||
}
|
||||
|
||||
// Status implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Status() health.Status {
|
||||
if lb.pool.Size() == 0 {
|
||||
return health.StatusUnknown
|
||||
}
|
||||
if len(lb.availServers()) == 0 {
|
||||
return health.StatusUnhealthy
|
||||
}
|
||||
return health.StatusHealthy
|
||||
}
|
||||
|
||||
// String implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) String() string {
|
||||
return lb.Name()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() []*Server {
|
||||
avail := make([]*Server, 0, lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, srv *Server) {
|
||||
if srv.Status().Good() {
|
||||
avail = append(avail, srv)
|
||||
}
|
||||
})
|
||||
return avail
|
||||
}
|
||||
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
for range 10 {
|
||||
lb.AddServer(&Server{})
|
||||
}
|
||||
lb.rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(new(Config))
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
}
|
||||
5
internal/net/http/loadbalancer/logger.go
Normal file
5
internal/net/http/loadbalancer/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package loadbalancer
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "load_balancer").Logger()
|
||||
32
internal/net/http/loadbalancer/mode.go
Normal file
32
internal/net/http/loadbalancer/mode.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
Unset Mode = ""
|
||||
RoundRobin Mode = "roundrobin"
|
||||
LeastConn Mode = "leastconn"
|
||||
IPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(RoundRobin):
|
||||
*mode = RoundRobin
|
||||
return true
|
||||
case string(LeastConn):
|
||||
*mode = LeastConn
|
||||
return true
|
||||
case string(IPHash):
|
||||
*mode = IPHash
|
||||
return true
|
||||
}
|
||||
*mode = RoundRobin
|
||||
return false
|
||||
}
|
||||
22
internal/net/http/loadbalancer/round_robin.go
Normal file
22
internal/net/http/loadbalancer/round_robin.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type roundRobin struct {
|
||||
index atomic.Uint32
|
||||
}
|
||||
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv *Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1) % uint32(len(srvs))
|
||||
srvs[index].ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
lb.index.Store(0)
|
||||
}
|
||||
}
|
||||
55
internal/net/http/loadbalancer/server.go
Normal file
55
internal/net/http/loadbalancer/server.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type (
|
||||
Server struct {
|
||||
_ U.NoCopy
|
||||
|
||||
Name string
|
||||
URL types.URL
|
||||
Weight weightType
|
||||
|
||||
handler http.Handler
|
||||
healthMon health.HealthMonitor
|
||||
}
|
||||
servers = []*Server
|
||||
Pool = F.Map[string, *Server]
|
||||
)
|
||||
|
||||
var newPool = F.NewMap[Pool]
|
||||
|
||||
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
|
||||
srv := &Server{
|
||||
Name: name,
|
||||
URL: url,
|
||||
Weight: weight,
|
||||
handler: handler,
|
||||
healthMon: healthMon,
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
srv.handler.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func (srv *Server) String() string {
|
||||
return srv.Name
|
||||
}
|
||||
|
||||
func (srv *Server) Status() health.Status {
|
||||
return srv.healthMon.Status()
|
||||
}
|
||||
|
||||
func (srv *Server) Uptime() time.Duration {
|
||||
return srv.healthMon.Uptime()
|
||||
}
|
||||
5
internal/net/http/logger.go
Normal file
5
internal/net/http/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package http
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "http").Logger()
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
@@ -15,9 +15,9 @@ type cidrWhitelist struct {
|
||||
}
|
||||
|
||||
type cidrWhitelistOpts struct {
|
||||
Allow []*types.CIDR
|
||||
StatusCode int
|
||||
Message string
|
||||
Allow []*types.CIDR `json:"allow"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
Message string `json:"message"`
|
||||
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
@@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
||||
}
|
||||
}
|
||||
|
||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
wl := new(cidrWhitelist)
|
||||
wl.m = &Middleware{
|
||||
impl: wl,
|
||||
@@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
return nil, err
|
||||
}
|
||||
if len(wl.cidrWhitelistOpts.Allow) == 0 {
|
||||
return nil, E.Missing("allow range")
|
||||
return nil, E.New("no allowed CIDRs")
|
||||
}
|
||||
return wl.m, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
errs := E.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
@@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
||||
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,14 +26,14 @@ const (
|
||||
var (
|
||||
cfCIDRsLastUpdate time.Time
|
||||
cfCIDRsMu sync.Mutex
|
||||
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP")
|
||||
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
||||
)
|
||||
|
||||
var CloudflareRealIP = &realIP{
|
||||
m: &Middleware{withOptions: NewCloudflareRealIP},
|
||||
}
|
||||
|
||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
||||
cri := new(realIP)
|
||||
cri.m = &Middleware{
|
||||
impl: cri,
|
||||
@@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval)
|
||||
cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
cfCIDRsLogger.Debugf("cloudflare CIDR range updated")
|
||||
cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,9 +109,9 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
||||
_, cidr, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
} else {
|
||||
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,43 +2,52 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/error_page"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
|
||||
)
|
||||
|
||||
var CustomErrorPage = &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
var CustomErrorPage *Middleware
|
||||
|
||||
func init() {
|
||||
CustomErrorPage = customErrorPage()
|
||||
}
|
||||
|
||||
func customErrorPage() *Middleware {
|
||||
m := &Middleware{
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
}
|
||||
m.modifyResponse = func(resp *Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode)
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode)
|
||||
CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage)))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
} else {
|
||||
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode)
|
||||
CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||
@@ -48,27 +57,27 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
|
||||
}
|
||||
if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) {
|
||||
filename := path[len(gphttp.StaticFilePathPrefix):]
|
||||
file, ok := error_page.GetStaticFile(filename)
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
errPageLogger.Errorf("unable to load resource %s", filename)
|
||||
logger.Error().Msg("unable to load resource " + filename)
|
||||
return false
|
||||
} else {
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
default:
|
||||
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
w.Write(file)
|
||||
return true
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set("Content-Type", "text/css; charset=utf-8")
|
||||
default:
|
||||
logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
logger.Err(err).Msg("unable to write resource " + filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var errPageLogger = logrus.WithField("middleware", "error_page")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package error_page
|
||||
package errorpage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
api "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
@@ -17,19 +18,34 @@ import (
|
||||
|
||||
const errPagesBasePath = common.ErrorPagesBasePath
|
||||
|
||||
var setup = sync.OnceFunc(func() {
|
||||
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath)
|
||||
var (
|
||||
setupMu sync.Mutex
|
||||
dirWatcher W.Watcher
|
||||
fileContentMap = F.NewMapOf[string, []byte]()
|
||||
)
|
||||
|
||||
func setup() {
|
||||
setupMu.Lock()
|
||||
defer setupMu.Unlock()
|
||||
|
||||
if dirWatcher != nil {
|
||||
return
|
||||
}
|
||||
|
||||
task := task.GlobalTask("error page")
|
||||
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
|
||||
loadContent()
|
||||
go watchDir()
|
||||
})
|
||||
go watchDir(task)
|
||||
}
|
||||
|
||||
func GetStaticFile(filename string) ([]byte, bool) {
|
||||
setup()
|
||||
return fileContentMap.Load(filename)
|
||||
}
|
||||
|
||||
// try <statusCode>.html -> 404.html -> not ok
|
||||
// try <statusCode>.html -> 404.html -> not ok.
|
||||
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
|
||||
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
|
||||
if !ok && statusCode != 404 {
|
||||
return fileContentMap.Load("404.html")
|
||||
}
|
||||
@@ -39,7 +55,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||
func loadContent() {
|
||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||
if err != nil {
|
||||
api.Logger.Error(err)
|
||||
logger.Err(err).Msg("failed to list error page resources")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
@@ -48,19 +64,21 @@ func loadContent() {
|
||||
}
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
api.Logger.Errorf("failed to read error page resource %s: %s", file, err)
|
||||
logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
continue
|
||||
}
|
||||
file = path.Base(file)
|
||||
api.Logger.Infof("error page resource %s loaded", file)
|
||||
logging.Info().Msgf("error page resource %s loaded", file)
|
||||
fileContentMap.Store(file, content)
|
||||
}
|
||||
}
|
||||
|
||||
func watchDir() {
|
||||
eventCh, errCh := dirWatcher.Events(context.Background())
|
||||
func watchDir(task task.Task) {
|
||||
eventCh, errCh := dirWatcher.Events(task.Context())
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case event, ok := <-eventCh:
|
||||
if !ok {
|
||||
return
|
||||
@@ -72,17 +90,14 @@ func watchDir() {
|
||||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
api.Logger.Infof("error page resource %s deleted", filename)
|
||||
logger.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
api.Logger.Infof("error page resource %s deleted", filename)
|
||||
logger.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
api.Logger.Errorf("error watching error page directory: %s", err)
|
||||
E.LogError("error watching error page directory", err, &logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var dirWatcher W.Watcher
|
||||
var fileContentMap = F.NewMapOf[string, []byte]()
|
||||
5
internal/net/http/middleware/errorpage/logger.go
Normal file
5
internal/net/http/middleware/errorpage/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package errorpage
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "errorpage").Logger()
|
||||
@@ -24,11 +24,12 @@ type (
|
||||
client http.Client
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
Address string
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
AddAuthCookiesToResponse []string
|
||||
transport http.RoundTripper
|
||||
Address string `json:"address"`
|
||||
TrustForwardHeader bool `json:"trustForwardHeader"`
|
||||
AuthResponseHeaders []string `json:"authResponseHeaders"`
|
||||
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
|
||||
|
||||
transport http.RoundTripper
|
||||
}
|
||||
)
|
||||
|
||||
@@ -36,16 +37,14 @@ var ForwardAuth = &forwardAuth{
|
||||
m: &Middleware{withOptions: NewForwardAuthfunc},
|
||||
}
|
||||
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
fa.forwardAuthOpts = new(forwardAuthOpts)
|
||||
err := Deserialize(optsRaw, fa.forwardAuthOpts)
|
||||
if err != nil {
|
||||
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = E.Check(url.Parse(fa.Address))
|
||||
if err != nil {
|
||||
return nil, E.Invalid("address", fa.Address)
|
||||
if _, err := url.Parse(fa.Address); err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
||||
fa.m = &Middleware{
|
||||
|
||||
5
internal/net/http/middleware/logger.go
Normal file
5
internal/net/http/middleware/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
var logger = logging.With().Str("module", "middleware").Logger()
|
||||
@@ -5,13 +5,14 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = E.NestedError
|
||||
Error = E.Error
|
||||
|
||||
ReverseProxy = gphttp.ReverseProxy
|
||||
ProxyRequest = gphttp.ProxyRequest
|
||||
@@ -24,12 +25,16 @@ type (
|
||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc func(resp *Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||
|
||||
OptionsRaw = map[string]any
|
||||
Options any
|
||||
|
||||
Middleware struct {
|
||||
_ U.NoCopy
|
||||
|
||||
zerolog.Logger
|
||||
|
||||
name string
|
||||
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
@@ -75,48 +80,57 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
if m.withOptions != nil {
|
||||
m, err := m.withOptions(optsRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return mWithOpt, nil
|
||||
}
|
||||
m.Logger = logger.With().Str("name", m.name).Logger()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{
|
||||
m.name,
|
||||
m.before,
|
||||
m.modifyResponse,
|
||||
nil,
|
||||
m.impl,
|
||||
m.parent,
|
||||
m.children,
|
||||
false,
|
||||
Logger: logger.With().Str("name", m.name).Logger(),
|
||||
name: m.name,
|
||||
before: m.before,
|
||||
modifyResponse: m.modifyResponse,
|
||||
impl: m.impl,
|
||||
parent: m.parent,
|
||||
children: m.children,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
|
||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if m.before != nil {
|
||||
m.before(next, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyResponse(resp *Response) error {
|
||||
if m.modifyResponse != nil {
|
||||
return m.modifyResponse(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
invalidOpts := E.NewBuilder("invalid options")
|
||||
defer func() {
|
||||
invalidM.Add(invalidOpts.Build())
|
||||
invalidM.To(&res)
|
||||
}()
|
||||
errs := E.NewBuilder("middlewares compile error")
|
||||
invalidOpts := E.NewBuilder("options compile error")
|
||||
|
||||
for name, opts := range middlewaresMap {
|
||||
m, ok := Get(name)
|
||||
if !ok {
|
||||
invalidM.Add(E.NotExist("middleware", name))
|
||||
m, err := Get(name)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := m.WithOptionsClone(opts)
|
||||
m, err = m.WithOptionsClone(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
@@ -124,10 +138,18 @@ func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[strin
|
||||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
if invalidM.HasError() {
|
||||
if invalidOpts.HasError() {
|
||||
errs.Add(invalidOpts.Error())
|
||||
}
|
||||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = createMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
patchReverseProxy(rpName, rp, middlewares)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,67 +4,63 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("read middleware compose file", err)
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
return BuildMiddlewaresFromYAML(fileContent)
|
||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
|
||||
b := E.NewBuilder("middlewares compile errors")
|
||||
defer b.To(&outErr)
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("yaml unmarshal", err))
|
||||
return
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
middlewares := make(map[string]*Middleware)
|
||||
for name, defs := range rawMap {
|
||||
chainErr := E.NewBuilder(name)
|
||||
chainErr := E.NewBuilder("")
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
|
||||
chainErr.Addf("item %d: missing field 'use'", i)
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, ok := Get(baseName)
|
||||
if !ok {
|
||||
base, ok = middlewares[baseName]
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
base, err := Get(baseName)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
chain = append(chain, m)
|
||||
}
|
||||
if chainErr.HasError() {
|
||||
b.Add(chainErr.Build())
|
||||
eb.Add(chainErr.Error().Subject(source))
|
||||
} else {
|
||||
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
|
||||
}
|
||||
}
|
||||
return
|
||||
return middlewares
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
m := &Middleware{name: name, children: chain}
|
||||
|
||||
@@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
}
|
||||
if len(modResps) > 0 {
|
||||
m.modifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware")
|
||||
errs := E.NewBuilder("modify response errors")
|
||||
for _, mr := range modResps {
|
||||
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
errs.Add(E.From(err).Subject(mr.name))
|
||||
}
|
||||
}
|
||||
return b.Build().Error()
|
||||
return errs.Error()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
|
||||
ExpectNoError(t, err.Error())
|
||||
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
|
||||
ExpectNoError(t, err.Error())
|
||||
errs := E.NewBuilder("")
|
||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
E.Must(json.MarshalIndent(middlewares, "", " "))
|
||||
// t.Log(string(data))
|
||||
// TODO: test
|
||||
}
|
||||
|
||||
@@ -6,30 +6,39 @@ import (
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var middlewares map[string]*Middleware
|
||||
var allMiddlewares map[string]*Middleware
|
||||
|
||||
func Get(name string) (middleware *Middleware, ok bool) {
|
||||
middleware, ok = middlewares[U.ToLowerNoSnake(name)]
|
||||
return
|
||||
var (
|
||||
ErrUnknownMiddleware = E.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = E.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return middlewares
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
// initialize middleware names and label parsers
|
||||
// initialize middleware names and label parsers.
|
||||
func init() {
|
||||
middlewares = map[string]*Middleware{
|
||||
allMiddlewares = map[string]*Middleware{
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"modifyresponse": ModifyResponse.m,
|
||||
"modifyrequest": ModifyRequest.m,
|
||||
"errorpage": CustomErrorPage,
|
||||
@@ -37,9 +46,13 @@ func init() {
|
||||
"realip": RealIP.m,
|
||||
"cloudflarerealip": CloudflareRealIP.m,
|
||||
"cidrwhitelist": CIDRWhiteList.m,
|
||||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth.m,
|
||||
// "oauth2": OAuth2.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
for name, m := range allMiddlewares {
|
||||
names[m] = append(names[m], http.CanonicalHeaderKey(name))
|
||||
}
|
||||
for m, names := range names {
|
||||
@@ -52,27 +65,30 @@ func init() {
|
||||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
b := E.NewBuilder("failed to load middlewares")
|
||||
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
errs := E.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to list middleware definitions: %s", err)
|
||||
logger.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
mws, err := BuildMiddlewaresFromComposeFile(defFile)
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
if _, ok := middlewares[name]; ok {
|
||||
b.Add(E.Duplicated("middleware", name))
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
errs.Add(ErrDuplicatedMiddleware.Subject(name))
|
||||
continue
|
||||
}
|
||||
middlewares[U.ToLowerNoSnake(name)] = m
|
||||
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
||||
allMiddlewares[strutils.ToLowerNoSnake(name)] = m
|
||||
logger.Info().
|
||||
Str("name", name).
|
||||
Str("src", path.Base(defFile)).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
b.Add(err.Subject(path.Base(defFile)))
|
||||
}
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
if errs.HasError() {
|
||||
E.LogError(errs.About(), errs.Error(), &logger)
|
||||
}
|
||||
}
|
||||
|
||||
var logger = logrus.WithField("module", "middlewares")
|
||||
|
||||
@@ -12,9 +12,9 @@ type (
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
SetHeaders map[string]string `json:"setHeaders"`
|
||||
AddHeaders map[string]string `json:"addHeaders"`
|
||||
HideHeaders []string `json:"hideHeaders"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{
|
||||
m: &Middleware{withOptions: NewModifyRequest},
|
||||
}
|
||||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyRequest)
|
||||
var mrFunc RewriteFunc
|
||||
if common.IsDebug {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user