mirror of
https://github.com/yusing/godoxy.git
synced 2026-02-10 04:27:42 +01:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bf520541b | ||
|
|
a531896bd6 | ||
|
|
e005b42d18 | ||
|
|
1f6573b6da | ||
|
|
73af381c4c | ||
|
|
625bf4dfdc | ||
|
|
46b4090629 | ||
|
|
91e012987e | ||
|
|
a86d316d07 | ||
|
|
76454df5e6 | ||
|
|
67b6e40f85 | ||
|
|
9889b5a8d3 |
22
.env.example
Normal file
22
.env.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# set timezone to get correct log timestamp
|
||||
TZ=ETC/UTC
|
||||
|
||||
# generate secret with `openssl rand -base64 32`
|
||||
GOPROXY_API_JWT_SECRET=
|
||||
|
||||
# the JWT token time-to-live
|
||||
GOPROXY_API_JWT_TOKEN_TTL=1h
|
||||
|
||||
# API/WebUI login credentials
|
||||
GOPROXY_API_USER=admin
|
||||
GOPROXY_API_PASSWORD=password
|
||||
|
||||
# Proxy listening address
|
||||
GOPROXY_HTTP_ADDR=:80
|
||||
GOPROXY_HTTPS_ADDR=:443
|
||||
|
||||
# API listening address
|
||||
GOPROXY_API_ADDR=127.0.0.1:8888
|
||||
|
||||
# Debug mode
|
||||
GOPROXY_DEBUG=false
|
||||
@@ -108,6 +108,7 @@ linters:
|
||||
- prealloc # Too many false-positive.
|
||||
- makezero # Not relevant
|
||||
- dupl # Too strict
|
||||
- gci # I don't care
|
||||
- gosec # Too strict
|
||||
- gochecknoinits
|
||||
- gochecknoglobals
|
||||
|
||||
43
README.md
43
README.md
@@ -24,6 +24,8 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
- [Key Features](#key-features)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Setup](#setup)
|
||||
- [Manual Setup](#manual-setup)
|
||||
- [Folder structrue](#folder-structrue)
|
||||
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
|
||||
- [Screenshots](#screenshots)
|
||||
- [idlesleeper](#idlesleeper)
|
||||
@@ -59,10 +61,12 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
docker pull ghcr.io/yusing/go-proxy:latest
|
||||
```
|
||||
|
||||
2. Create new directory, `cd` into it, then run setup
|
||||
2. Create new directory, `cd` into it, then run setup, or [set up manually](#manual-setup)
|
||||
|
||||
```shell
|
||||
docker run --rm -v .:/setup ghcr.io/yusing/go-proxy /app/go-proxy setup
|
||||
# Then set the JWT secret
|
||||
sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env
|
||||
```
|
||||
|
||||
3. Setup DNS Records point to machine which runs `go-proxy`, e.g.
|
||||
@@ -83,6 +87,43 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
|
||||
[🔼Back to top](#table-of-content)
|
||||
|
||||
### Manual Setup
|
||||
|
||||
1. Make `config` directory then grab `config.example.yml` into `config/config.yml`
|
||||
|
||||
`mkdir -p config && wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/config.example.yml -O config/config.yml`
|
||||
|
||||
2. Grab `.env.example` into `.env`
|
||||
|
||||
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/.env.example -O .env`
|
||||
|
||||
3. Grab `compose.example.yml` into `compose.yml`
|
||||
|
||||
`wget https://raw.githubusercontent.com/yusing/go-proxy/v0.7/compose.example.yml -O compose.yml`
|
||||
|
||||
4. Set the JWT secret
|
||||
|
||||
`sed -i "s|GOPROXY_API_JWT_SECRET=.*|GOPROXY_API_JWT_SECRET=$(openssl rand -base64 32)|g" .env`
|
||||
|
||||
5. Start the container `docker compose up -d`
|
||||
|
||||
### Folder structrue
|
||||
|
||||
```shell
|
||||
├── certs
|
||||
│ ├── cert.crt
|
||||
│ └── priv.key
|
||||
├── compose.yml
|
||||
├── config
|
||||
│ ├── config.yml
|
||||
│ ├── middlewares
|
||||
│ │ ├── middleware1.yml
|
||||
│ │ ├── middleware2.yml
|
||||
│ ├── provider1.yml
|
||||
│ └── provider2.yml
|
||||
└── .env
|
||||
```
|
||||
|
||||
### Use JSON Schema in VSCode
|
||||
|
||||
Copy [`.vscode/settings.example.json`](.vscode/settings.example.json) to `.vscode/settings.json` and modify it to fit your needs
|
||||
|
||||
62
cmd/main.go
62
cmd/main.go
@@ -26,17 +26,40 @@ import (
|
||||
func main() {
|
||||
args := common.GetArgs()
|
||||
|
||||
if args.Command == common.CommandSetup {
|
||||
switch args.Command {
|
||||
case common.CommandSetup:
|
||||
internal.Setup()
|
||||
return
|
||||
}
|
||||
|
||||
if args.Command == common.CommandReload {
|
||||
case common.CommandReload:
|
||||
if err := query.ReloadServer(); err != nil {
|
||||
E.LogFatal("server reload error", err)
|
||||
}
|
||||
logging.Info().Msg("ok")
|
||||
return
|
||||
case common.CommandListIcons:
|
||||
icons, err := internal.ListAvailableIcons()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(icons)
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := query.ListRoutes()
|
||||
if err != nil {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
printJSON(config.RoutesByAlias())
|
||||
} else {
|
||||
printJSON(routes)
|
||||
}
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
return
|
||||
}
|
||||
|
||||
if args.Command == common.CommandStart {
|
||||
@@ -75,43 +98,12 @@ func main() {
|
||||
case common.CommandListConfigs:
|
||||
printJSON(config.Value())
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := query.ListRoutes()
|
||||
if err != nil {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
printJSON(config.RoutesByAlias())
|
||||
} else {
|
||||
printJSON(routes)
|
||||
}
|
||||
return
|
||||
case common.CommandListIcons:
|
||||
icons, err := internal.ListAvailableIcons()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(icons)
|
||||
return
|
||||
case common.CommandDebugListEntries:
|
||||
printJSON(config.DumpEntries())
|
||||
return
|
||||
case common.CommandDebugListProviders:
|
||||
printJSON(config.DumpProviders())
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
return
|
||||
case common.CommandDebugListTasks:
|
||||
tasks, err := query.DebugListTasks()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(tasks)
|
||||
return
|
||||
}
|
||||
|
||||
cfg.StartProxyProviders()
|
||||
|
||||
@@ -4,31 +4,26 @@ services:
|
||||
container_name: go-proxy-frontend
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
env_file: .env
|
||||
depends_on:
|
||||
- app
|
||||
# if you also want to proxy the WebUI and access it via gp.y.z
|
||||
# labels:
|
||||
# - proxy.aliases=gp
|
||||
# - proxy.gp.port=3000
|
||||
|
||||
# Make sure the value is same as `GOPROXY_API_ADDR` below (if you have changed it)
|
||||
#
|
||||
# environment:
|
||||
# GOPROXY_API_ADDR: 127.0.0.1:8888
|
||||
# modify below to fit your needs
|
||||
labels:
|
||||
proxy.aliases: gp
|
||||
proxy.#1.port: 3000
|
||||
proxy.#1.middlewares.cidr_whitelist.status_code: 403
|
||||
proxy.#1.middlewares.cidr_whitelist.message: IP not allowed
|
||||
proxy.#1.middlewares.cidr_whitelist.allow: |
|
||||
- 127.0.0.1
|
||||
- 10.0.0.0/8
|
||||
- 192.168.0.0/16
|
||||
- 172.16.0.0/12
|
||||
app:
|
||||
image: ghcr.io/yusing/go-proxy:latest
|
||||
container_name: go-proxy
|
||||
restart: always
|
||||
network_mode: host
|
||||
environment:
|
||||
# (Optional) change this to your timezone to get correct log timestamp
|
||||
TZ: ETC/UTC
|
||||
|
||||
# Change these if you need
|
||||
#
|
||||
# GOPROXY_HTTP_ADDR: :80
|
||||
# GOPROXY_HTTPS_ADDR: :443
|
||||
# GOPROXY_API_ADDR: 127.0.0.1:8888
|
||||
env_file: .env
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./config:/app/config
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
@@ -26,8 +25,6 @@ func NewHandler() http.Handler {
|
||||
mux := NewServeMux()
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
// mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth)
|
||||
// mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth)
|
||||
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
||||
@@ -59,9 +56,3 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
||||
f(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func wrap(cfg *config.Config, f func(cfg *config.Config, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
f(cfg, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,10 +30,6 @@ var (
|
||||
ErrInvalidPassword = E.New("invalid password")
|
||||
)
|
||||
|
||||
const tokenExpiration = 24 * time.Hour
|
||||
|
||||
const jwtClaimKeyUsername = "username"
|
||||
|
||||
func validatePassword(cred *Credentials) error {
|
||||
if cred.Username != common.APIUser {
|
||||
return ErrInvalidUsername.Subject(cred.Username)
|
||||
@@ -56,7 +52,7 @@ func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(tokenExpiration)
|
||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
||||
claim := &Claims{
|
||||
Username: creds.Username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
@@ -114,7 +110,7 @@ func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
var claims Claims
|
||||
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"])
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return common.APIJWTSecret, nil
|
||||
})
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
func CheckHealth(w http.ResponseWriter, r *http.Request) {
|
||||
target := r.FormValue("target")
|
||||
if target == "" {
|
||||
HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
result, ok := health.Inspect(target)
|
||||
if !ok {
|
||||
HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
json, err := result.MarshalJSON()
|
||||
if err != nil {
|
||||
HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
RespondJSON(w, r, json)
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func ReloadServer() E.Error {
|
||||
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||
resp, err := U.Post(common.APIHTTPURL+"/v1/reload", "", nil)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
|
||||
@@ -19,18 +19,21 @@ func Stats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func StatsWS(w http.ResponseWriter, r *http.Request) {
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
|
||||
var originPats []string
|
||||
|
||||
if len(originPats) == 0 {
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
|
||||
if len(config.Value().MatchDomains) == 0 {
|
||||
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
originPats = make([]string, len(config.Value().MatchDomains))
|
||||
for i, domain := range config.Value().MatchDomains {
|
||||
originPats[i] = "*." + domain
|
||||
originPats[i] = "*" + domain
|
||||
}
|
||||
originPats = append(originPats, localAddresses...)
|
||||
}
|
||||
U.LogInfo(r).Msgf("websocket API request from origins: %s", originPats)
|
||||
if common.IsDebug {
|
||||
originPats = []string{"*"}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
@@ -23,7 +24,7 @@ func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int)
|
||||
|
||||
switch data := data.(type) {
|
||||
case string:
|
||||
j = []byte(`"` + data + `"`)
|
||||
j = []byte(fmt.Sprintf("%q", data))
|
||||
case []byte:
|
||||
j = data
|
||||
default:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/go-acme/lego/v4/providers/dns/clouddns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/duckdns"
|
||||
@@ -31,7 +29,3 @@ var providersGenMap = map[string]ProviderGenerator{
|
||||
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
|
||||
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
|
||||
}
|
||||
|
||||
var (
|
||||
ErrGetCertFailure = errors.New("get certificate failed")
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ package autocert
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
@@ -35,6 +36,8 @@ type (
|
||||
CertExpiries map[string]time.Time
|
||||
)
|
||||
|
||||
var ErrGetCertFailure = errors.New("get certificate failed")
|
||||
|
||||
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if p.tlsCert == nil {
|
||||
return nil, ErrGetCertFailure
|
||||
@@ -248,10 +251,7 @@ func (p *Provider) renewIfNeeded() E.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.ObtainCert(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return p.ObtainCert()
|
||||
}
|
||||
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ const (
|
||||
CommandDebugListEntries = "debug-ls-entries"
|
||||
CommandDebugListProviders = "debug-ls-providers"
|
||||
CommandDebugListMTrace = "debug-ls-mtrace"
|
||||
CommandDebugListTasks = "debug-ls-tasks"
|
||||
)
|
||||
|
||||
var ValidCommands = []string{
|
||||
@@ -35,7 +34,6 @@ var ValidCommands = []string{
|
||||
CommandDebugListEntries,
|
||||
CommandDebugListProviders,
|
||||
CommandDebugListMTrace,
|
||||
CommandDebugListTasks,
|
||||
}
|
||||
|
||||
func GetArgs() Args {
|
||||
|
||||
@@ -13,7 +13,8 @@ const (
|
||||
// file, folder structure
|
||||
|
||||
const (
|
||||
DotEnvPath = ".env"
|
||||
DotEnvPath = ".env"
|
||||
DotEnvExamplePath = ".env.example"
|
||||
|
||||
ConfigBasePath = "config"
|
||||
ConfigFileName = "config.yml"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -33,6 +34,7 @@ var (
|
||||
APIHTTPURL = GetAddrEnv("GOPROXY_API_ADDR", "127.0.0.1:8888", "http")
|
||||
|
||||
APIJWTSecret = decodeJWTKey(GetEnv("GOPROXY_API_JWT_SECRET", generateJWTKey(32)))
|
||||
APIJWTTokenTTL = GetDurationEnv("GOPROXY_API_JWT_TOKEN_TTL", time.Hour)
|
||||
APIUser = GetEnv("GOPROXY_API_USER", "admin")
|
||||
APIPasswordHash = HashPassword(GetEnv("GOPROXY_API_PASSWORD", "password"))
|
||||
)
|
||||
@@ -69,3 +71,15 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
||||
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
|
||||
return
|
||||
}
|
||||
|
||||
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
|
||||
value, ok := os.LookupEnv(key)
|
||||
if !ok || value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
log.Fatal().Msgf("env %s: invalid duration value: %s", key, value)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -177,6 +178,11 @@ func (cfg *Config) load() E.Error {
|
||||
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
||||
|
||||
cfg.value = model
|
||||
for i, domain := range model.MatchDomains {
|
||||
if !strings.HasPrefix(domain, ".") {
|
||||
model.MatchDomains[i] = "." + domain
|
||||
}
|
||||
}
|
||||
route.SetFindMuxDomains(model.MatchDomains)
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -69,17 +69,18 @@ func HomepageConfig() homepage.Config {
|
||||
)
|
||||
}
|
||||
|
||||
if entry.IsDocker(r) {
|
||||
switch {
|
||||
case entry.IsDocker(r):
|
||||
if item.Category == "" {
|
||||
item.Category = "Docker"
|
||||
}
|
||||
item.SourceType = string(proxy.ProviderTypeDocker)
|
||||
} else if entry.UseLoadBalance(r) {
|
||||
case entry.UseLoadBalance(r):
|
||||
if item.Category == "" {
|
||||
item.Category = "Load-balanced"
|
||||
}
|
||||
item.SourceType = "loadbalancer"
|
||||
} else {
|
||||
default:
|
||||
if item.Category == "" {
|
||||
item.Category = "Others"
|
||||
}
|
||||
|
||||
@@ -52,13 +52,10 @@ func (c *SharedClient) Connected() bool {
|
||||
}
|
||||
|
||||
// if the client is still referenced, this is no-op.
|
||||
func (c *SharedClient) Close() error {
|
||||
if !c.Connected() {
|
||||
return nil
|
||||
func (c *SharedClient) Close() {
|
||||
if c.Connected() {
|
||||
c.refCount.Sub()
|
||||
}
|
||||
|
||||
c.refCount.Sub()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectClient creates a new Docker client connection to the specified host.
|
||||
@@ -115,7 +112,6 @@ func ConnectClient(host string) (Client, error) {
|
||||
}
|
||||
|
||||
client, err := client.NewClientWithOpts(opt...)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package idlewatcher
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
@@ -21,7 +20,7 @@ var loadingPage []byte
|
||||
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
|
||||
|
||||
func (w *Watcher) makeLoadingPageBody() []byte {
|
||||
msg := fmt.Sprintf("%s is starting...", w.ContainerName)
|
||||
msg := w.ContainerName + " is starting..."
|
||||
|
||||
data := new(templateData)
|
||||
data.CheckRedirectHeader = common.HeaderCheckRedirect
|
||||
|
||||
@@ -45,17 +45,18 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
|
||||
return nil, E.Errorf("register watcher: %w", err)
|
||||
}
|
||||
|
||||
if rp != nil {
|
||||
switch {
|
||||
case rp != nil:
|
||||
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
|
||||
} else if stream != nil {
|
||||
case stream != nil:
|
||||
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
|
||||
} else {
|
||||
default:
|
||||
panic("both nil")
|
||||
}
|
||||
return watcher, nil
|
||||
}
|
||||
|
||||
// lifetime should follow route provider
|
||||
// lifetime should follow route provider.
|
||||
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
|
||||
return newWaker(providerSubTask, entry, rp, nil)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
shouldNext := w.wakeFromHTTP(rw, r)
|
||||
if !shouldNext {
|
||||
@@ -81,7 +81,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
||||
w.WakeTrace().Msg("signal received")
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.WakeError(err).Send()
|
||||
w.WakeError(err)
|
||||
http.Error(rw, "Error waking container", http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func (w *Watcher) Accept() (conn types.StreamConn, err error) {
|
||||
return
|
||||
}
|
||||
if wakeErr := w.wakeFromStream(); wakeErr != nil {
|
||||
w.WakeError(wakeErr).Msg("error waking from stream")
|
||||
w.WakeError(wakeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func (w *Watcher) wakeFromStream() error {
|
||||
wakeErr := w.wakeIfStopped()
|
||||
if wakeErr != nil {
|
||||
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
|
||||
w.WakeError(wakeErr).Msg("wake failed")
|
||||
w.WakeError(wakeErr)
|
||||
return wakeErr
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
@@ -108,8 +107,8 @@ func (w *Watcher) WakeTrace() *zerolog.Event {
|
||||
return w.Trace().Str("action", "wake")
|
||||
}
|
||||
|
||||
func (w *Watcher) WakeError(err error) *zerolog.Event {
|
||||
return w.Err(err).Str("action", "wake")
|
||||
func (w *Watcher) WakeError(err error) {
|
||||
w.Err(err).Str("action", "wake").Msg("error")
|
||||
}
|
||||
|
||||
func (w *Watcher) LogReason(action, reason string) {
|
||||
@@ -204,17 +203,17 @@ func (w *Watcher) resetIdleTimer() {
|
||||
|
||||
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
|
||||
eventTask = w.task.Subtask("docker event watcher")
|
||||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
|
||||
Filters: W.NewDockerFilter(
|
||||
W.DockerFilterContainer,
|
||||
W.DockerFilterContainerNameID(w.ContainerID),
|
||||
W.DockerFilterStart,
|
||||
W.DockerFilterStop,
|
||||
W.DockerFilterDie,
|
||||
W.DockerFilterKill,
|
||||
W.DockerFilterDestroy,
|
||||
W.DockerFilterPause,
|
||||
W.DockerFilterUnpause,
|
||||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
|
||||
Filters: watcher.NewDockerFilter(
|
||||
watcher.DockerFilterContainer,
|
||||
watcher.DockerFilterContainerNameID(w.ContainerID),
|
||||
watcher.DockerFilterStart,
|
||||
watcher.DockerFilterStop,
|
||||
watcher.DockerFilterDie,
|
||||
watcher.DockerFilterKill,
|
||||
watcher.DockerFilterDestroy,
|
||||
watcher.DockerFilterPause,
|
||||
watcher.DockerFilterUnpause,
|
||||
),
|
||||
})
|
||||
return
|
||||
@@ -230,9 +229,9 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
|
||||
// stop method.
|
||||
//
|
||||
// it exits only if the context is canceled, the container is destroyed,
|
||||
// errors occured on docker client, or route provider died (mainly caused by config reload).
|
||||
// errors occurred on docker client, or route provider died (mainly caused by config reload).
|
||||
func (w *Watcher) watchUntilDestroy() (returnCause error) {
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
dockerWatcher := watcher.NewDockerWatcherWithClient(w.client)
|
||||
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
|
||||
defer eventTask.Finish("stopped")
|
||||
|
||||
@@ -279,9 +278,13 @@ func (w *Watcher) watchUntilDestroy() (returnCause error) {
|
||||
case <-w.ticker.C:
|
||||
w.ticker.Stop()
|
||||
if w.ContainerRunning {
|
||||
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
err := w.stopByMethod()
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
continue
|
||||
case err != nil:
|
||||
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
|
||||
} else {
|
||||
default:
|
||||
w.LogReason("container stopped", "idle timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,26 +27,31 @@ func (b *Builder) HasError() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
func (b *Builder) Error() Error {
|
||||
func (b *Builder) error() Error {
|
||||
if !b.HasError() {
|
||||
return nil
|
||||
}
|
||||
if len(b.errs) == 1 {
|
||||
return From(b.errs[0])
|
||||
}
|
||||
return &nestedError{Err: New(b.about), Extras: b.errs}
|
||||
}
|
||||
|
||||
func (b *Builder) Error() Error {
|
||||
if len(b.errs) == 1 {
|
||||
return From(b.errs[0])
|
||||
}
|
||||
return b.error()
|
||||
}
|
||||
|
||||
func (b *Builder) String() string {
|
||||
if !b.HasError() {
|
||||
err := b.error()
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return (&nestedError{Err: New(b.about), Extras: b.errs}).Error()
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// Add adds an error to the Builder.
|
||||
//
|
||||
// adding nil is no-op,
|
||||
// adding nil is no-op.
|
||||
func (b *Builder) Add(err error) *Builder {
|
||||
if err == nil {
|
||||
return b
|
||||
@@ -90,6 +95,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) AddFrom(other *Builder, flatten bool) *Builder {
|
||||
if other == nil || !other.HasError() {
|
||||
return b
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
if flatten {
|
||||
b.errs = append(b.errs, other.errs...)
|
||||
} else {
|
||||
b.errs = append(b.errs, other.error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) AddRange(errs ...error) *Builder {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
@@ -22,7 +22,7 @@ type Error interface {
|
||||
Subjectf(format string, args ...any) Error
|
||||
}
|
||||
|
||||
// this makes JSON marshalling work,
|
||||
// this makes JSON marshaling work,
|
||||
// as the builtin one doesn't.
|
||||
type errStr string
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package error
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrInvalidErrorJson = errors.New("invalid error json")
|
||||
|
||||
func newError(message string) error {
|
||||
return errStr(message)
|
||||
}
|
||||
|
||||
@@ -15,13 +15,14 @@ func init() {
|
||||
var level zerolog.Level
|
||||
var exclude []string
|
||||
|
||||
if common.IsTrace {
|
||||
switch {
|
||||
case common.IsTrace:
|
||||
timeFmt = "04:05"
|
||||
level = zerolog.TraceLevel
|
||||
} else if common.IsDebug {
|
||||
case common.IsDebug:
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.DebugLevel
|
||||
} else {
|
||||
default:
|
||||
timeFmt = "01-02 15:04"
|
||||
level = zerolog.InfoLevel
|
||||
exclude = []string{"module"}
|
||||
|
||||
@@ -16,13 +16,12 @@ var (
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: defaultDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
DefaultTransportNoTLS = func() *http.Transport {
|
||||
var clone = DefaultTransport.Clone()
|
||||
clone := DefaultTransport.Clone()
|
||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
return clone
|
||||
}()
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ContentType string
|
||||
type AcceptContentType []ContentType
|
||||
type (
|
||||
ContentType string
|
||||
AcceptContentType []ContentType
|
||||
)
|
||||
|
||||
func GetContentType(h http.Header) ContentType {
|
||||
ct := h.Get("Content-Type")
|
||||
|
||||
@@ -55,7 +55,6 @@ func New(cfg *Config) *LoadBalancer {
|
||||
Logger: logger.With().Str("name", cfg.Link).Logger(),
|
||||
Config: new(Config),
|
||||
pool: newPool(),
|
||||
task: task.DummyTask(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
@@ -14,7 +14,7 @@ const (
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch U.ToLowerNoSnake(string(*mode)) {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(RoundRobin):
|
||||
|
||||
@@ -109,9 +109,9 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
||||
_, cidr, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
} else {
|
||||
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,24 +19,33 @@ import (
|
||||
const errPagesBasePath = common.ErrorPagesBasePath
|
||||
|
||||
var (
|
||||
setupMu sync.Mutex
|
||||
dirWatcher W.Watcher
|
||||
fileContentMap = F.NewMapOf[string, []byte]()
|
||||
)
|
||||
|
||||
var setup = sync.OnceFunc(func() {
|
||||
func setup() {
|
||||
setupMu.Lock()
|
||||
defer setupMu.Unlock()
|
||||
|
||||
if dirWatcher != nil {
|
||||
return
|
||||
}
|
||||
|
||||
task := task.GlobalTask("error page")
|
||||
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
|
||||
loadContent()
|
||||
go watchDir(task)
|
||||
})
|
||||
}
|
||||
|
||||
func GetStaticFile(filename string) ([]byte, bool) {
|
||||
setup()
|
||||
return fileContentMap.Load(filename)
|
||||
}
|
||||
|
||||
// try <statusCode>.html -> 404.html -> not ok.
|
||||
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode))
|
||||
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
|
||||
if !ok && statusCode != 404 {
|
||||
return fileContentMap.Load("404.html")
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
@@ -21,7 +20,7 @@ var (
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)]
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
@@ -34,7 +33,7 @@ func All() map[string]*Middleware {
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
// initialize middleware names and label parsers
|
||||
// initialize middleware names and label parsers.
|
||||
func init() {
|
||||
allMiddlewares = map[string]*Middleware{
|
||||
"setxforwarded": SetXForwarded,
|
||||
@@ -67,7 +66,7 @@ func init() {
|
||||
|
||||
func LoadComposeFiles() {
|
||||
errs := E.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
@@ -82,7 +81,7 @@ func LoadComposeFiles() {
|
||||
errs.Add(ErrDuplicatedMiddleware.Subject(name))
|
||||
continue
|
||||
}
|
||||
allMiddlewares[U.ToLowerNoSnake(name)] = m
|
||||
allMiddlewares[strutils.ToLowerNoSnake(name)] = m
|
||||
logger.Info().
|
||||
Str("name", name).
|
||||
Str("src", path.Base(defFile)).
|
||||
|
||||
@@ -22,8 +22,10 @@ type Trace struct {
|
||||
|
||||
type Traces []*Trace
|
||||
|
||||
var traces = Traces{}
|
||||
var tracesMu sync.Mutex
|
||||
var (
|
||||
traces = make(Traces, 0)
|
||||
tracesMu sync.Mutex
|
||||
)
|
||||
|
||||
const MaxTraceNum = 100
|
||||
|
||||
|
||||
@@ -10,18 +10,20 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ModifyResponseFunc func(*http.Response) error
|
||||
type ModifyResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
type (
|
||||
ModifyResponseFunc func(*http.Response) error
|
||||
ModifyResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
|
||||
headerSent bool
|
||||
code int
|
||||
headerSent bool
|
||||
code int
|
||||
|
||||
modifier ModifyResponseFunc
|
||||
modified bool
|
||||
modifierErr error
|
||||
}
|
||||
modifier ModifyResponseFunc
|
||||
modified bool
|
||||
modifierErr error
|
||||
}
|
||||
)
|
||||
|
||||
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
|
||||
return &ModifyResponseWriter{
|
||||
|
||||
@@ -404,7 +404,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
|
||||
err = U.Copy2(req.Context(), rw, res.Body)
|
||||
_, err = io.Copy(rw, res.Body)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
p.errorHandler(rw, req, err, true)
|
||||
|
||||
@@ -53,15 +53,15 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, er
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, Providers)))
|
||||
}
|
||||
if provider, err := createFunc(cfg); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
|
||||
provider, err := createFunc(cfg)
|
||||
if err == nil {
|
||||
dispatcher.providers.Add(provider)
|
||||
configSubTask.OnCancel("remove provider", func() {
|
||||
dispatcher.providers.Remove(provider)
|
||||
})
|
||||
return provider, nil
|
||||
}
|
||||
return provider, err
|
||||
}
|
||||
|
||||
func (disp *Dispatcher) start() {
|
||||
|
||||
@@ -68,7 +68,7 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProx
|
||||
port := E.Collect(errs, fields.ValidatePort, m.Port)
|
||||
pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns)
|
||||
url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port))
|
||||
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
|
||||
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, cont)
|
||||
|
||||
if errs.HasError() {
|
||||
return nil
|
||||
|
||||
@@ -61,7 +61,7 @@ func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry {
|
||||
port := E.Collect(errs, fields.ValidateStreamPort, m.Port)
|
||||
scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme)
|
||||
url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort))
|
||||
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
|
||||
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, cont)
|
||||
|
||||
if errs.HasError() {
|
||||
return nil
|
||||
|
||||
@@ -89,7 +89,6 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
|
||||
r := &HTTPRoute{
|
||||
ReverseProxyEntry: entry,
|
||||
rp: rp,
|
||||
task: task.DummyTask(),
|
||||
l: logger.With().
|
||||
Str("type", string(entry.Scheme)).
|
||||
Str("name", string(entry.Alias)).
|
||||
@@ -211,40 +210,51 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
||||
|
||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
mux, err := findMuxFunc(r.Host)
|
||||
if err == nil {
|
||||
mux.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway?
|
||||
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
|
||||
// Then scraper / scanners will know the subdomain is invalid.
|
||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||
if err != nil {
|
||||
if !middleware.ServeStaticErrorPageFile(w, r) {
|
||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||
if ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if _, err := w.Write(errorPage); err != nil {
|
||||
logger.Err(err).Msg("failed to write error page")
|
||||
}
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
if !middleware.ServeStaticErrorPageFile(w, r) {
|
||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||
if ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if _, err := w.Write(errorPage); err != nil {
|
||||
logger.Err(err).Msg("failed to write error page")
|
||||
}
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func findMuxAnyDomain(host string) (http.Handler, error) {
|
||||
hostSplit := strings.Split(host, ".")
|
||||
n := len(hostSplit)
|
||||
if n <= 2 {
|
||||
switch {
|
||||
case n == 3:
|
||||
host = hostSplit[0]
|
||||
case n > 3:
|
||||
var builder strings.Builder
|
||||
builder.Grow(2*n - 3)
|
||||
builder.WriteString(hostSplit[0])
|
||||
for _, part := range hostSplit[:n-2] {
|
||||
builder.WriteRune('.')
|
||||
builder.WriteString(part)
|
||||
}
|
||||
host = builder.String()
|
||||
default:
|
||||
return nil, errors.New("missing subdomain in url")
|
||||
}
|
||||
sd := strings.Join(hostSplit[:n-2], ".")
|
||||
if r, ok := httpRoutes.Load(sd); ok {
|
||||
if r, ok := httpRoutes.Load(host); ok {
|
||||
return r.handler, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no such route: %s", sd)
|
||||
return nil, fmt.Errorf("no such route: %s", host)
|
||||
}
|
||||
|
||||
func findMuxByDomains(domains []string) func(host string) (http.Handler, error) {
|
||||
@@ -252,20 +262,18 @@ func findMuxByDomains(domains []string) func(host string) (http.Handler, error)
|
||||
var subdomain string
|
||||
|
||||
for _, domain := range domains {
|
||||
if !strings.HasPrefix(domain, ".") {
|
||||
domain = "." + domain
|
||||
}
|
||||
subdomain = strings.TrimSuffix(host, domain)
|
||||
if len(subdomain) < len(host) {
|
||||
if strings.HasSuffix(host, domain) {
|
||||
subdomain = strings.TrimSuffix(host, domain)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(subdomain) == len(host) { // not matched
|
||||
return nil, fmt.Errorf("%s does not match any base domain", host)
|
||||
|
||||
if subdomain != "" { // matched
|
||||
if r, ok := httpRoutes.Load(subdomain); ok {
|
||||
return r.handler, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no such route: %s", subdomain)
|
||||
}
|
||||
if r, ok := httpRoutes.Load(subdomain); ok {
|
||||
return r.handler, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no such route: %s", subdomain)
|
||||
return nil, fmt.Errorf("%s does not match any base domain", host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ type EventHandler struct {
|
||||
updated *E.Builder
|
||||
}
|
||||
|
||||
func (provider *Provider) newEventHandler() *EventHandler {
|
||||
func (p *Provider) newEventHandler() *EventHandler {
|
||||
return &EventHandler{
|
||||
provider: provider,
|
||||
provider: p,
|
||||
errs: E.NewBuilder("event errors"),
|
||||
added: E.NewBuilder("added"),
|
||||
removed: E.NewBuilder("removed"),
|
||||
@@ -60,11 +60,12 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
|
||||
|
||||
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
|
||||
newr, ok := newRoutes.Load(k)
|
||||
if !ok {
|
||||
switch {
|
||||
case !ok:
|
||||
handler.Remove(oldr)
|
||||
} else if handler.matchAny(events, newr) {
|
||||
case handler.matchAny(events, newr):
|
||||
handler.Update(parent, oldr, newr)
|
||||
} else if entry.ShouldNotServe(newr) {
|
||||
case entry.ShouldNotServe(newr):
|
||||
handler.Remove(oldr)
|
||||
}
|
||||
})
|
||||
@@ -122,11 +123,11 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Log() {
|
||||
results := E.NewBuilder("event occured")
|
||||
results.Add(handler.added.Error())
|
||||
results.Add(handler.removed.Error())
|
||||
results.Add(handler.updated.Error())
|
||||
results.Add(handler.errs.Error())
|
||||
results := E.NewBuilder("event occurred")
|
||||
results.AddFrom(handler.added, false)
|
||||
results.AddFrom(handler.removed, false)
|
||||
results.AddFrom(handler.updated, false)
|
||||
results.AddFrom(handler.errs, false)
|
||||
if result := results.String(); result != "" {
|
||||
handler.provider.Logger().Info().Msg(result)
|
||||
}
|
||||
|
||||
@@ -45,9 +45,7 @@ const (
|
||||
providerEventFlushInterval = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyProviderName = errors.New("empty provider name")
|
||||
)
|
||||
var ErrEmptyProviderName = errors.New("empty provider name")
|
||||
|
||||
func newProvider(name string, t ProviderType) *Provider {
|
||||
return &Provider{
|
||||
@@ -109,12 +107,11 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
subtask.Finish(err) // just to ensure
|
||||
return err.Subject(r.Entry.Alias)
|
||||
} else {
|
||||
p.routes.Store(r.Entry.Alias, r)
|
||||
subtask.OnFinished("del from provider", func() {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
})
|
||||
}
|
||||
p.routes.Store(r.Entry.Alias, r)
|
||||
subtask.OnFinished("del from provider", func() {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -80,11 +80,12 @@ func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
|
||||
entries.RangeAllParallel(func(alias string, en *entry.RawEntry) {
|
||||
en.Alias = alias
|
||||
r, err := NewRoute(en)
|
||||
if err != nil {
|
||||
switch {
|
||||
case err != nil:
|
||||
b.Add(err.Subject(alias))
|
||||
} else if entry.ShouldNotServe(r) {
|
||||
case entry.ShouldNotServe(r):
|
||||
return
|
||||
} else {
|
||||
default:
|
||||
routes.Store(alias, r)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
type StreamRoute struct {
|
||||
*entry.StreamEntry
|
||||
|
||||
stream net.Stream `json:"-"`
|
||||
stream net.Stream
|
||||
|
||||
HealthMon health.HealthMonitor `json:"health"`
|
||||
|
||||
@@ -44,7 +44,6 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
|
||||
}
|
||||
return &StreamRoute{
|
||||
StreamEntry: entry,
|
||||
task: task.DummyTask(),
|
||||
l: logger.With().
|
||||
Str("type", string(entry.Scheme.ListeningScheme)).
|
||||
Str("name", entry.TargetName()).
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
baseURL = "https://github.com/yusing/go-proxy/raw/" + branch
|
||||
requiredConfigs = []Config{
|
||||
{common.ConfigBasePath, true, false, ""},
|
||||
{common.DotEnvPath, false, true, common.DotEnvExamplePath},
|
||||
{common.ComposeFileName, false, true, common.ComposeExampleFileName},
|
||||
{path.Join(common.ConfigBasePath, common.ConfigFileName), false, true, common.ConfigExampleFileName},
|
||||
}
|
||||
@@ -40,7 +41,7 @@ func Setup() {
|
||||
config.setup()
|
||||
}
|
||||
|
||||
log.Println("done")
|
||||
log.Println("setup finished")
|
||||
}
|
||||
|
||||
func (c *Config) setup() {
|
||||
@@ -96,7 +97,7 @@ func fetch(remoteFilename string, outFileName string) {
|
||||
log.Printf("%q already exists, downloading to %q\n", outFileName, remoteFilename)
|
||||
outFileName = remoteFilename
|
||||
}
|
||||
log.Printf("downloading %q\n", remoteFilename)
|
||||
log.Printf("downloading %q to %q\n", remoteFilename, outFileName)
|
||||
|
||||
url, err := url.JoinPath(baseURL, remoteFilename)
|
||||
if err != nil {
|
||||
@@ -120,7 +121,7 @@ func fetch(remoteFilename string, outFileName string) {
|
||||
log.Fatalf("failed to write to file: %s\n", err)
|
||||
}
|
||||
|
||||
log.Printf("downloaded to %q\n", outFileName)
|
||||
log.Print("done")
|
||||
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package task
|
||||
|
||||
import "context"
|
||||
|
||||
type dummyTask struct{}
|
||||
|
||||
func DummyTask() (_ Task) {
|
||||
return
|
||||
}
|
||||
|
||||
// Context implements Task.
|
||||
func (d dummyTask) Context() context.Context {
|
||||
panic("call of dummyTask.Context")
|
||||
}
|
||||
|
||||
// Finish implements Task.
|
||||
func (d dummyTask) Finish() {}
|
||||
|
||||
// Name implements Task.
|
||||
func (d dummyTask) Name() string {
|
||||
return "Dummy Task"
|
||||
}
|
||||
|
||||
// OnComplete implements Task.
|
||||
func (d dummyTask) OnComplete(about string, fn func()) {
|
||||
panic("call of dummyTask.OnComplete")
|
||||
}
|
||||
|
||||
// Parent implements Task.
|
||||
func (d dummyTask) Parent() Task {
|
||||
panic("call of dummyTask.Parent")
|
||||
}
|
||||
|
||||
// Subtask implements Task.
|
||||
func (d dummyTask) Subtask(usageFmt string, args ...any) Task {
|
||||
panic("call of dummyTask.Subtask")
|
||||
}
|
||||
|
||||
// Wait implements Task.
|
||||
func (d dummyTask) Wait() {}
|
||||
|
||||
// WaitSubTasks implements Task.
|
||||
func (d dummyTask) WaitSubTasks() {}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
@@ -28,8 +27,6 @@ func createGlobalTask() (t *task) {
|
||||
type (
|
||||
// Task controls objects' lifetime.
|
||||
//
|
||||
// Task must be initialized, use DummyTask if the task is not yet started.
|
||||
//
|
||||
// Objects that uses a task should implement the TaskStarter and the TaskFinisher interface.
|
||||
//
|
||||
// When passing a Task object to another function,
|
||||
@@ -167,8 +164,8 @@ func GlobalContextWait(timeout time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *task) trace() *zerolog.Event {
|
||||
return logger.Trace().Str("name", t.name)
|
||||
func (t *task) trace(msg string) {
|
||||
logger.Trace().Str("name", t.name).Msg(msg)
|
||||
}
|
||||
|
||||
func (t *task) Name() string {
|
||||
@@ -244,7 +241,7 @@ func (t *task) OnCancel(about string, fn func()) {
|
||||
<-t.ctx.Done()
|
||||
fn()
|
||||
onCompTask.Finish("done")
|
||||
t.trace().Msg("onCancel done: " + about)
|
||||
t.trace("onCancel done: " + about)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -284,10 +281,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
|
||||
parent.subTasksWg.Add(1)
|
||||
parent.subtasks.Add(subtask)
|
||||
if common.IsTrace {
|
||||
subtask.trace().Msg("started")
|
||||
subtask.trace("started")
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
subtask.trace().Msg("finished: " + subtask.FinishCause().Error())
|
||||
subtask.trace("finished: " + subtask.FinishCause().Error())
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
|
||||
@@ -102,7 +102,7 @@ func (p BidirectionalPipe) Start() E.Error {
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// This is a copy of io.Copy with context handling
|
||||
// Author: yusing <yusing@6uo.me>
|
||||
// Author: yusing <yusing@6uo.me>.
|
||||
func Copy(dst *ContextWriter, src *ContextReader) (err error) {
|
||||
size := 32 * 1024
|
||||
if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N {
|
||||
|
||||
@@ -18,9 +18,10 @@ func NearestField(input string, s any) string {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() == reflect.Struct {
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
fields = make([]string, 0)
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
for i := range t.NumField() {
|
||||
jsonTag, ok := t.Field(i).Tag.Lookup("json")
|
||||
if ok {
|
||||
fields = append(fields, jsonTag)
|
||||
@@ -28,13 +29,13 @@ func NearestField(input string, s any) string {
|
||||
fields = append(fields, t.Field(i).Name)
|
||||
}
|
||||
}
|
||||
} else if t.Kind() == reflect.Map {
|
||||
case reflect.Map:
|
||||
keys := reflect.ValueOf(s).MapKeys()
|
||||
fields = make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fields[i] = key.String()
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
panic("unsupported type: " + t.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ var (
|
||||
ErrInvalidType = E.New("invalid type")
|
||||
ErrNilValue = E.New("nil")
|
||||
ErrUnsettable = E.New("unsettable")
|
||||
ErrUnsupportedConvertion = E.New("unsupported convertion")
|
||||
ErrUnsupportedConversion = E.New("unsupported conversion")
|
||||
ErrMapMissingColon = E.New("map missing colon")
|
||||
ErrMapTooManyColons = E.New("map too many colons")
|
||||
ErrUnknownField = E.New("unknown field")
|
||||
@@ -176,10 +176,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
case reflect.Struct:
|
||||
mapping := make(map[string]reflect.Value)
|
||||
for _, field := range reflect.VisibleFields(dstT) {
|
||||
mapping[ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name)
|
||||
mapping[strutils.ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name)
|
||||
}
|
||||
for k, v := range src {
|
||||
if field, ok := mapping[ToLowerNoSnake(k)]; ok {
|
||||
if field, ok := mapping[strutils.ToLowerNoSnake(k)]; ok {
|
||||
err := Convert(reflect.ValueOf(v), field)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(k))
|
||||
@@ -199,11 +199,11 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(k))
|
||||
}
|
||||
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp)
|
||||
dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
|
||||
}
|
||||
return errs.Error()
|
||||
default:
|
||||
return ErrUnsupportedConvertion.Subject("deserialize to " + dstT.String())
|
||||
return ErrUnsupportedConversion.Subject("deserialize to " + dstT.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,12 +250,12 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
case srcT.Kind() == reflect.Map:
|
||||
obj, ok := src.Interface().(SerializedObject)
|
||||
if !ok {
|
||||
return ErrUnsupportedConvertion.Subject(dstT.String() + " to " + srcT.String())
|
||||
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
|
||||
}
|
||||
return Deserialize(obj, dst.Addr().Interface())
|
||||
case srcT.Kind() == reflect.Slice:
|
||||
if dstT.Kind() != reflect.Slice {
|
||||
return ErrUnsupportedConvertion.Subject(dstT.String() + " to slice")
|
||||
return ErrUnsupportedConversion.Subject(dstT.String() + " to slice")
|
||||
}
|
||||
newSlice := reflect.MakeSlice(dstT, 0, src.Len())
|
||||
i := 0
|
||||
@@ -280,7 +280,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
var ok bool
|
||||
// check if (*T).Convertor is implemented
|
||||
if converter, ok = dst.Addr().Interface().(Converter); !ok {
|
||||
return ErrUnsupportedConvertion.Subjectf("%s to %s", srcT, dstT)
|
||||
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
|
||||
}
|
||||
|
||||
return converter.ConvertFrom(src.Interface())
|
||||
@@ -310,6 +310,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
||||
}
|
||||
dst.Set(reflect.ValueOf(d))
|
||||
return
|
||||
default:
|
||||
}
|
||||
// primitive types / simple types
|
||||
switch dst.Kind() {
|
||||
@@ -392,7 +393,3 @@ func DeserializeJSON(j map[string]string, target any) error {
|
||||
}
|
||||
return json.Unmarshal(data, target)
|
||||
}
|
||||
|
||||
func ToLowerNoSnake(s string) string {
|
||||
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package strutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -57,6 +58,10 @@ func ParseBool(s string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func PortString(port uint16) string {
|
||||
return strconv.FormatUint(uint64(port), 10)
|
||||
}
|
||||
|
||||
func DoYouMean(s string) string {
|
||||
return "Did you mean " + ansi.HighlightGreen + s + ansi.Reset + "?"
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package strutils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
@@ -29,8 +28,8 @@ func ExtractPort(fullURL string) (int, error) {
|
||||
return Atoi(url.Port())
|
||||
}
|
||||
|
||||
func PortString(port uint16) string {
|
||||
return strconv.FormatUint(uint64(port), 10)
|
||||
func ToLowerNoSnake(s string) string {
|
||||
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
|
||||
}
|
||||
|
||||
func LevenshteinDistance(a, b string) int {
|
||||
@@ -60,7 +59,7 @@ func LevenshteinDistance(a, b string) int {
|
||||
cost = 1
|
||||
}
|
||||
|
||||
v1[j+1] = min(v1[j]+1, v0[j+1]+1, v0[j]+cost)
|
||||
v1[j+1] = min3(v1[j]+1, v0[j+1]+1, v0[j]+cost)
|
||||
}
|
||||
|
||||
for j := 0; j <= len(b); j++ {
|
||||
@@ -71,7 +70,7 @@ func LevenshteinDistance(a, b string) int {
|
||||
return v1[len(b)]
|
||||
}
|
||||
|
||||
func min(a, b, c int) int {
|
||||
func min3(a, b, c int) int {
|
||||
if a < b && a < c {
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
var logger = logging.With().Str("module", "health_mon").Logger()
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -31,34 +30,24 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
var monMap = F.NewMapOf[string, HealthMonitor]()
|
||||
|
||||
var (
|
||||
ErrNegativeInterval = errors.New("negative interval")
|
||||
)
|
||||
var ErrNegativeInterval = errors.New("negative interval")
|
||||
|
||||
func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
|
||||
mon := &monitor{
|
||||
config: config,
|
||||
checkHealth: healthCheckFunc,
|
||||
startTime: time.Now(),
|
||||
task: task.DummyTask(),
|
||||
}
|
||||
mon.url.Store(url)
|
||||
mon.status.Store(StatusHealthy)
|
||||
return mon
|
||||
}
|
||||
|
||||
func Inspect(service string) (HealthMonitor, bool) {
|
||||
return monMap.Load(service)
|
||||
}
|
||||
|
||||
func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) {
|
||||
if mon.task != nil {
|
||||
return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause))
|
||||
} else {
|
||||
return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause))
|
||||
}
|
||||
return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause))
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
@@ -85,9 +74,6 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error {
|
||||
return
|
||||
}
|
||||
|
||||
monMap.Store(mon.service, mon)
|
||||
defer monMap.Delete(mon.service)
|
||||
|
||||
ticker := time.NewTicker(mon.config.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user