Compare commits

...

21 Commits

Author SHA1 Message Date
yusing
69361aea1b fixed host set to localhost even on remote docker, fixed one error in provider causing all routes not to load 2024-09-21 18:23:20 +08:00
yusing
26e2154c64 fixed startup crash for file provider 2024-09-21 17:22:17 +08:00
Yuzerion
a29bf880bc Update docker.md
Too sleepy...
2024-09-21 16:08:11 +08:00
Yuzerion
1f6d03bdbb Update compose.example.yml 2024-09-21 16:07:12 +08:00
Yuzerion
4a7d898b8e Update docker.md 2024-09-21 16:06:32 +08:00
Yuzerion
521b694aec Update docker.md 2024-09-21 15:56:39 +08:00
yusing
a351de7441 github CI fix attempt 2024-09-21 14:32:52 +08:00
yusing
ab2dc26b76 fixing udp stream listening on wrong port 2024-09-21 14:18:29 +08:00
yusing
9a81b13b67 fixing tcp/udp error on closing 2024-09-21 13:40:20 +08:00
yusing
626bd9666b check release 2024-09-21 12:45:56 +08:00
yusing
d7eab2ebcd fixing idlewatcher 2024-09-21 09:42:40 +08:00
yusing
e48b9bbb0a 新增繁中README (未完成) 2024-09-19 21:16:38 +08:00
yusing
339411530b v0.5.0-rc5: merge 2024-09-19 20:42:12 +08:00
yusing
4a2d42bfa9 v0.5.0-rc5: check release 2024-09-19 20:40:03 +08:00
Yuzerion
81da9ad83a small fix 2024-09-18 09:10:41 +08:00
yusing
be7a766cb2 v0.5.0-rc5: added proxy.exclude label, refactored some code 2024-09-17 17:56:41 +08:00
yusing
83d1d027c6 added TZ env to docker compose example 2024-09-17 12:36:13 +08:00
yusing
21fcceb391 v0.5.0-rc4: initial support for ovh, provider generator implementation update, replaced all interface{} to any 2024-09-17 12:06:58 +08:00
yusing
82f06374f7 v0.5.0-rc4: fixing autocert issue, cache ACME registration, added ls-config option 2024-09-17 08:41:36 +08:00
yusing
04fd6543fd README update for sonarcloud badges, simplify some test code, fixed some sonarlint issues 2024-09-17 04:51:26 +08:00
yusing
409a18df38 update default branch for setup script 2024-09-17 03:54:55 +08:00
92 changed files with 3275 additions and 1578 deletions

View File

@@ -15,7 +15,7 @@ jobs:
tags: ${{ github.ref_name }}
- name: Tag as latest
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
if: startsWith(github.ref, 'refs/tags/') && !contains(github.ref_name, '-')
run: |
docker tag ghcr.io/${{ github.repository }}:${{ github.ref_name }} ghcr.io/${{ github.repository }}:latest
docker push ghcr.io/${{ github.repository }}:latest

View File

@@ -2,13 +2,12 @@ FROM golang:1.23.1-alpine AS builder
COPY src /src
ENV GOCACHE=/root/.cache/go-build
WORKDIR /src
RUN --mount=type=cache,target="/go/pkg/mod" \
go mod download
RUN --mount=type=cache,target="/go/pkg/mod" \
--mount=type=cache,target="/root/.cache/go-build" \
go mod download && \
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy github.com/yusing/go-proxy
FROM alpine:latest
FROM alpine:3.20
LABEL maintainer="yusing@6uo.me"

View File

@@ -12,7 +12,7 @@ build:
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
test:
cd src && go test ./... && cd ..
go test ./src/...
up:
docker compose up -d
@@ -27,7 +27,7 @@ get:
cd src && go get -u && go mod tidy && cd ..
debug:
make build && GOPROXY_DEBUG=1 bin/go-proxy
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
archive:
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
@@ -42,3 +42,6 @@ rapid-crash:
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\
sleep 3 &&\
sudo docker rm -f test_crash
debug-list-containers:
bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq'

View File

@@ -1,14 +1,24 @@
# go-proxy
A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse proxy and load balancer with a web UI.
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Security Rating](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=security_rating)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Maintainability Rating](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=sqale_rating)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
**Table of content**
[繁體中文文檔請看此](README_CHT.md)
A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI.
## Table of content
<!-- TOC -->
- [go-proxy](#go-proxy)
- [Table of content](#table-of-content)
- [Key Points](#key-points)
- [Getting Started](#getting-started)
- [Setup](#setup)
- [Commands line arguments](#commands-line-arguments)
- [Environment variables](#environment-variables)
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
@@ -20,10 +30,14 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
## Key Points
- Easy to use
- Effortless configuration
- Error messages is clear and detailed, easy troubleshooting
- Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md))
- Auto configuration for docker contaienrs
- Auto configuration for docker containers
- Auto hot-reload on container state / config file changes
- Support HTTP(s), TCP and UDP
- Stop containers on idle, wake it up on traffic _(optional)_
- HTTP(s) reserve proxy
- TCP and UDP port forwarding
- Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots))
- Written in **[Go](https://go.dev)**
@@ -31,7 +45,9 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
## Getting Started
1. Setup DNS Records
### Setup
1. Setup DNS Records, e.g.
- A Record: `*.y.z` -> `10.0.10.1`
- AAAA Record: `*.y.z` -> `::ffff:a00:a01`
@@ -39,27 +55,32 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
2. Setup `go-proxy` [See here](docs/docker.md)
3. Configure `go-proxy`
- with text editor (i.e. Visual Studio Code)
- with text editor (e.g. Visual Studio Code)
- or with web config editor via `http://gp.y.z`
[🔼Back to top](#table-of-content)
### Commands line arguments
| Argument | Description |
| ---------- | -------------------------------- |
| empty | start proxy server |
| `validate` | validate config and exit |
| `reload` | trigger a force reload of config |
| Argument | Description | Example |
| ----------- | -------------------------------- | -------------------------- |
| empty | start proxy server | |
| `validate` | validate config and exit | |
| `reload` | trigger a force reload of config | |
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
**run with `docker exec <container_name> /app/go-proxy <command>`**
### Environment variables
| Environment Variable | Description | Default | Values |
| ------------------------------ | ------------------------- | ------- | ------- |
| `GOPROXY_NO_SCHEMA_VALIDATION` | disable schema validation | `false` | boolean |
| `GOPROXY_DEBUG` | enable debug behaviors | `false` | boolean |
| Environment Variable | Description | Default | Values |
| ------------------------------ | ----------------------------- | ------- | ------- |
| `GOPROXY_NO_SCHEMA_VALIDATION` | disable schema validation | `false` | boolean |
| `GOPROXY_DEBUG` | enable debug behaviors | `false` | boolean |
| `GOPROXY_HTTP_PORT` | http server port | `80` | integer |
| `GOPROXY_HTTPS_PORT` | http server port (if enabled) | `443` | integer |
| `GOPROXY_API_PORT` | api server port | `8888` | integer |
### Use JSON Schema in VSCode
@@ -95,7 +116,7 @@ providers:
### Provider File
Fields are same as [docker labels](docs/docker.md#labels) starting from `scheme`
See [Fields](docs/docker.md#fields)
See [providers.example.yml](providers.example.yml) for examples
@@ -105,6 +126,8 @@ See [providers.example.yml](providers.example.yml) for examples
- Cert "renewal" is actually obtaining a new cert instead of renewing the existing one
- `autocert` config is not hot-reloadable
[🔼Back to top](#table-of-content)
## Build it yourself
@@ -119,6 +142,4 @@ See [providers.example.yml](providers.example.yml) for examples
5. build binary with `make build`
6. start your container with `make up` (docker) or `bin/go-proxy` (binary)
[🔼Back to top](#table-of-content)

140
README_CHT.md Normal file
View File

@@ -0,0 +1,140 @@
# go-proxy
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Security Rating](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=security_rating)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Maintainability Rating](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=sqale_rating)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![Vulnerabilities](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=vulnerabilities)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
一個輕量化、易用且[高效](docs/benchmark_result.md)的反向代理工具
## 目錄
<!-- TOC -->
- [go-proxy](#go-proxy)
- [目錄](#目錄)
- [重點](#重點)
- [入門指南](#入門指南)
- [安裝](#安裝)
- [命令行參數](#命令行參數)
- [環境變量](#環境變量)
- [VSCode 中使用 JSON Schema](#vscode-中使用-json-schema)
- [配置文件](#配置文件)
- [透過文件配置](#透過文件配置)
- [已知問題](#已知問題)
- [源碼編譯](#源碼編譯)
## 重點
- 易用
- 不需花費太多時間就能輕鬆配置
- 除錯簡單
- 自動處理 HTTPS 證書(參見[可用的 DNS 供應商](docs/dns_providers.md)
- 透過 Docker 容器自動配置
- 容器狀態變更時自動熱重載
- 容器閒置時自動暫停/停止,入站時自動喚醒
- HTTP(s)反向代理
- TCP/UDP 端口轉發
- 用於配置和監控的前端 Web 面板([截圖](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots)
- 使用 **[Go](https://go.dev)** 編寫
[🔼 返回頂部](#目錄)
## 入門指南
### 安裝
1. 設置 DNS 記錄,例如:
- A 記錄: `*.y.z` -> `10.0.10.1`
- AAAA 記錄: `*.y.z` -> `::ffff:a00:a01`
2. 安裝 `go-proxy` [參見這裡](docs/docker.md)
3. 配置 `go-proxy`
- 使用文本編輯器 (推薦 Visual Studio Code [參見 VSCode 使用 schema](#vscode-中使用-json-schema))
- 或通過 `http://gp.y.z` 使用網頁配置編輯器
[🔼 返回頂部](#目錄)
### 命令行參數
| 參數 | 描述 | 示例 |
| ----------- | -------------- | -------------------------- |
| 空 | 啟動代理服務器 | |
| `validate` | 驗證配置並退出 | |
| `reload` | 強制刷新配置 | |
| `ls-config` | 列出配置並退出 | `go-proxy ls-config \| jq` |
| `ls-route` | 列出路由並退出 | `go-proxy ls-route \| jq` |
**使用 `docker exec <容器名稱> /app/go-proxy <參數>` 運行**
### 環境變量
| 環境變量 | 描述 | 默認 | 值 |
| ------------------------------ | ---------------- | ------- | ------- |
| `GOPROXY_NO_SCHEMA_VALIDATION` | 禁用 schema 驗證 | `false` | boolean |
| `GOPROXY_DEBUG` | 啟用調試輸出 | `false` | boolean |
### VSCode 中使用 JSON Schema
複製 [`.vscode/settings.example.json`](.vscode/settings.example.json) 到 `.vscode/settings.json` 並根據需求修改
[🔼 返回頂部](#目錄)
### 配置文件
參見 [config.example.yml](config.example.yml) 了解更多
```yaml
# autocert 配置
autocert:
email: # ACME 電子郵件
domains: # 域名列表
provider: # DNS 供應商
options: # 供應商個別配置
- ...
# 配置文件 / docker
providers:
include:
- providers.yml
- other_file_1.yml
- ...
docker:
local: $DOCKER_HOST
remote-1: tcp://10.0.2.1:2375
remote-2: ssh://root:1234@10.0.2.2
```
[🔼 返回頂部](#目錄)
### 透過文件配置
參見 [Fields](docs/docker.md#fields)
參見範例 [providers.example.yml](providers.example.yml)
[🔼 返回頂部](#目錄)
## 已知問題
- 證書“更新”實際上是獲取新證書而不是更新現有證書
- `autocert` 配置不能熱重載
[🔼 返回頂部](#目錄)
## 源碼編譯
1. 獲取源碼 `git clone https://github.com/yusing/go-proxy --depth=1`
2. 安裝/升級 [go 版本 (>=1.22)](https://go.dev/doc/install) 和 `make`(如果尚未安裝)
3. 如果之前編譯過go 版本 < 1.22),請使用 `go clean -cache` 清除緩存
4. 使用 `make get` 獲取依賴項
5. 使用 `make build` 編譯
[🔼 返回頂部](#目錄)

View File

@@ -6,7 +6,7 @@ services:
network_mode: host
labels:
- proxy.aliases=gp
- proxy.gp.port=8888
- proxy.gp.port=3000
depends_on:
- app
app:
@@ -14,8 +14,11 @@ services:
container_name: go-proxy
restart: always
network_mode: host
environment:
# (Optional) change this to your timezone to get correct log timestamp
TZ: ETC/UTC
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- /var/run/docker.sock:/var/run/docker.sock
- ./config:/app/config
# (Optional) choose one of below to enable https

View File

@@ -7,7 +7,7 @@
```go
var providersGenMap = map[string]ProviderGenerator{
"cloudflare": providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
// add here, i.e.
// add here, e.g.
"clouddns": providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
}
```

View File

@@ -1,11 +1,13 @@
# Supported DNS Providers
<!-- TOC -->
- [Cloudflare](#cloudflare)
- [CloudDNS](#clouddns)
- [DuckDNS](#duckdns)
- [Implement other DNS providers](#implement-other-dns-providers)
<!-- /TOC -->
- [Supported DNS Providers](#supported-dns-providers)
- [Cloudflare](#cloudflare)
- [CloudDNS](#clouddns)
- [DuckDNS](#duckdns)
- [OVHCloud](#ovhcloud)
- [Implement other DNS providers](#implement-other-dns-providers)
## Cloudflare
@@ -23,10 +25,29 @@ Follow [this guide](https://cloudkul.com/blog/automcatic-renew-and-generate-ssl-
## DuckDNS
`token`: DuckDNS Token
- `token`: DuckDNS Token
Tested by [earvingad](https://github.com/earvingad)
## OVHCloud
_Note, `application_key` and `oauth2_config` **CANNOT** be used together_
- `api_endpoint`: Endpoint URL, or one of
- `ovh-eu`,
- `ovh-ca`,
- `ovh-us`,
- `kimsufi-eu`,
- `kimsufi-ca`,
- `soyoustart-eu`,
- `soyoustart-ca`
- `application_secret`
- `application_key`
- `consumer_key`
- `oauth2_config`: Client ID and Client Secret
- `client_id`
- `client_secret`
## Implement other DNS providers
See [add_dns_provider.md](docs/add_dns_provider.md)

View File

@@ -85,11 +85,18 @@
### Syntax
| Label | Description | Default |
| ----------------------- | -------------------------------------------------------- | ---------------- |
| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` |
| `proxy.<alias>.<field>` | set field for specific alias | N/A |
| `proxy.*.<field>` | set field for all aliases | N/A |
| Label | Description | Default | Accepted values |
| ------------------------ | --------------------------------------------------------------------- | -------------------- | ------------------------------------------------------------------------- |
| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | any |
| `proxy.exclude` | to be excluded from `go-proxy` | false | boolean |
| `proxy.idle_timeout` | time for idle (no traffic) before put it into sleep **(http/s only)** | empty **(disabled)** | `number[unit]...`, e.g. `1m30s` |
| `proxy.wake_timeout` | time to wait for container to start before responding a loading page | empty | `number[unit]...` |
| `proxy.stop_method` | method to stop after `idle_timeout` | `stop` | `stop`, `pause`, `kill` |
| `proxy.stop_timeout` | time to wait for stop command | `10s` | `number[unit]...` |
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
| `proxy.<alias>.<field>` | set field for specific alias | N/A | N/A |
| `proxy.$<index>.<field>` | set field for specific alias at index (starting from **1**) | N/A | N/A |
| `proxy.*.<field>` | set field for all aliases | N/A | N/A |
### Fields
@@ -172,7 +179,7 @@ service_a:
- Container not showing up in proxies list
Please check that either `ports` or label `proxy.<alias>.port` is declared, i.e.
Please check that either `ports` or label `proxy.<alias>.port` is declared, e.g.
```yaml
services:
@@ -183,6 +190,7 @@ service_a:
nginx-2: # Option 2
...
container_name: nginx-2
network_mode: host
labels:
proxy.nginx-2.port: 80
```
@@ -220,19 +228,25 @@ services:
restart: unless-stopped
labels:
- proxy.aliases=adg,adg-dns,adg-setup
- proxy.adg.port=80
- proxy.adg-setup.port=3000
- proxy.adg-dns.scheme=udp
- proxy.adg-dns.port=20000:dns
- proxy.$1.port=80
- proxy.$2.scheme=udp
- proxy.$2.port=20000:dns
- proxy.$3.port=3000
volumes:
- adg-work:/opt/adguardhome/work
- adg-conf:/opt/adguardhome/conf
ports:
- 80
- 3000
- 53/udp
mc:
image: itzg/minecraft-server
tty: true
stdin_open: true
container_name: mc
restart: unless-stopped
ports:
- 25565
labels:
- proxy.mc.scheme=tcp
- proxy.mc.port=20001:25565
@@ -245,11 +259,14 @@ services:
restart: unless-stopped
container_name: pal
stop_grace_period: 30s
ports:
- 8211/udp
- 27015/udp
labels:
- proxy.aliases=pal1,pal2
- proxy.*.scheme=udp
- proxy.pal1.port=20002:8211
- proxy.pal2.port=20003:27015
- proxy.$1.port=20002:8211
- proxy.$2.port=20003:27015
environment: ...
volumes:
- palworld:/palworld
@@ -260,6 +277,8 @@ services:
- nginx:/usr/share/nginx/html
ports:
- 80
labels:
proxy.idle_timeout: 1m
go-proxy:
image: ghcr.io/yusing/go-proxy:latest
container_name: go-proxy
@@ -267,7 +286,7 @@ services:
network_mode: host
volumes:
- ./config:/app/config
- /var/run/docker.sock:/var/run/docker.sock:ro
- /var/run/docker.sock:/var/run/docker.sock
go-proxy-frontend:
image: ghcr.io/yusing/go-proxy-frontend:latest
container_name: go-proxy-frontend
@@ -275,7 +294,7 @@ services:
network_mode: host
labels:
- proxy.aliases=gp
- proxy.gp.port=8888
- proxy.gp.port=3000
depends_on:
- go-proxy
```
@@ -288,8 +307,8 @@ services:
- `adg-setup.yourdomain.com`: adguard setup (first time setup)
- `adg.yourdomain.com`: adguard dashboard
- `nginx.yourdomain.com`: nginx
- `yourdomain.com:53`: adguard dns
- `yourdomain.com:25565`: minecraft server
- `yourdomain.com:8211`: palworld server
- `yourdomain.com:2000`: adguard dns (udp)
- `yourdomain.com:20001`: minecraft server
- `yourdomain.com:20002`: palworld server
[🔼Back to top](#table-of-content)

View File

@@ -1,4 +1,4 @@
go 1.22
go 1.22.0
toolchain go1.23.1

View File

@@ -37,7 +37,7 @@
"title": "DNS Challenge Provider",
"default": "local",
"type": "string",
"enum": ["local", "cloudflare", "clouddns", "duckdns"]
"enum": ["local", "cloudflare", "clouddns", "duckdns", "ovh"]
},
"options": {
"title": "Provider specific options",
@@ -135,6 +135,82 @@
}
}
}
},
{
"if": {
"properties": {
"provider": {
"const": "ovh"
}
}
},
"then": {
"properties": {
"options": {
"required": ["application_secret", "consumer_key"],
"additionalProperties": false,
"oneOf": [
{
"required": ["application_key"]
},
{
"required": ["oauth2_config"]
}
],
"properties": {
"api_endpoint": {
"description": "OVH API endpoint",
"default": "ovh-eu",
"anyOf": [
{
"enum": [
"ovh-eu",
"ovh-ca",
"ovh-us",
"kimsufi-eu",
"kimsufi-ca",
"soyoustart-eu",
"soyoustart-ca"
]
},
{
"type": "string",
"format": "uri"
}
]
},
"application_secret": {
"description": "OVH Application Secret",
"type": "string"
},
"consumer_key": {
"description": "OVH Consumer Key",
"type": "string"
},
"application_key": {
"description": "OVH Application Key",
"type": "string"
},
"oauth2_config": {
"description": "OVH OAuth2 config",
"type": "object",
"additionalProperties": false,
"properties": {
"client_id": {
"description": "OVH Client ID",
"type": "string"
},
"client_secret": {
"description": "OVH Client Secret",
"type": "string"
}
},
"required": ["client_id", "client_secret"]
}
}
}
}
}
}
]
},

View File

@@ -2,7 +2,7 @@
set -e
if [ -z "$BRANCH" ]; then
BRANCH="main"
BRANCH="v0.5"
fi
BASE_URL="https://github.com/yusing/go-proxy/raw/${BRANCH}"
mkdir -p go-proxy

View File

@@ -3,6 +3,7 @@ package v1
import (
"fmt"
"net/http"
"strings"
U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/config"
@@ -17,17 +18,19 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
}
var ok bool
route := cfg.FindRoute(target)
switch route := cfg.FindRoute(target).(type) {
case nil:
switch {
case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return
case *R.HTTPRoute:
ok = U.IsSiteHealthy(route.TargetURL.String())
case *R.StreamRoute:
case route.Type() == R.RouteTypeReverseProxy:
ok = U.IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = U.IsStreamHealthy(
string(route.Scheme.ProxyScheme),
fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort),
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)
}

View File

@@ -9,7 +9,6 @@ import (
U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/proxy/provider"
)
@@ -32,25 +31,25 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest)
return
}
content, err := E.Check(io.ReadAll(r.Body))
if err.IsNotNil() {
content, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
if filename == common.ConfigFileName {
err = config.Validate(content)
err = config.Validate(content).Error()
} else {
err = provider.Validate(content)
err = provider.Validate(content).Error()
}
if err.IsNotNil() {
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
err = E.From(os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644))
if err.IsNotNil() {
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
if err != nil {
U.HandleErr(w, r, err)
return
}

View File

@@ -29,10 +29,10 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias()
type_filter := r.FormValue("type")
if type_filter != "" {
typeFilter := r.FormValue("type")
if typeFilter != "" {
for k, v := range routes {
if v["type"] != type_filter {
if v["type"] != typeFilter {
delete(routes, k)
}
}

View File

@@ -8,7 +8,7 @@ import (
)
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err.IsNotNil() {
if err := cfg.Reload().Error(); err != nil {
U.HandleErr(w, r, err)
return
}

View File

@@ -10,7 +10,7 @@ import (
)
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
stats := map[string]interface{}{
stats := map[string]any{
"proxies": cfg.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
}

View File

@@ -9,14 +9,14 @@ import (
E "github.com/yusing/go-proxy/error"
)
func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
err = E.From(err).Subjectf("%s %s", r.Method, r.URL)
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
logrus.WithField("module", "api").Error(err)
if len(code) > 0 {
http.Error(w, err.Error(), code[0])
http.Error(w, err.String(), code[0])
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, err.String(), http.StatusInternalServerError)
}
func ErrMissingKey(k string) error {

View File

@@ -44,7 +44,7 @@ func ReloadServer() E.NestedError {
if resp.StatusCode != http.StatusOK {
return E.Failure("server reload").Subjectf("status code: %v", resp.StatusCode)
}
return E.Nil()
return nil
}
var HttpClient = &http.Client{

View File

@@ -20,59 +20,56 @@ func NewConfig(cfg *M.AutoCertConfig) *Config {
if cfg.KeyPath == "" {
cfg.KeyPath = KeyFileDefault
}
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (*Provider, E.NestedError) {
errors := E.NewBuilder("cannot create autocert provider")
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
errors.Addf("no domains specified")
b.Addf("no domains specified")
}
if cfg.Provider == "" {
errors.Addf("no provider specified")
b.Addf("no provider specified")
}
if cfg.Email == "" {
errors.Addf("no email specified")
b.Addf("no email specified")
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Addf("unknown provider: %q", cfg.Provider)
}
}
gen, ok := providersGenMap[cfg.Provider]
if !ok {
errors.Addf("unknown provider: %q", cfg.Provider)
}
if err := errors.Build(); err.IsNotNil() {
return nil, err
if b.HasError() {
return
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.IsNotNil() {
return nil, E.Failure("generate private key").With(err)
if err.HasError() {
b.Add(E.FailWith("generate private key", err))
return
}
user := &User{
Email: cfg.Email,
key: privKey,
}
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
legoClient, err := E.Check(lego.NewClient(legoCfg))
if err.IsNotNil() {
return nil, E.Failure("create lego client").With(err)
}
base := &Provider{
provider = &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
client: legoClient,
}
legoProvider, err := E.Check(gen(cfg.Options))
if err.IsNotNil() {
return nil, E.Failure("create lego provider").With(err)
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.IsNotNil() {
return nil, E.Failure("set challenge provider").With(err)
}
return base, E.Nil()
return
}

View File

@@ -1,16 +1,20 @@
package autocert
import (
"errors"
"github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/sirupsen/logrus"
)
const (
certBasePath = "certs/"
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
certBasePath = "certs/"
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
RegistrationFile = certBasePath + "registration.json"
)
const (
@@ -18,14 +22,19 @@ const (
ProviderCloudflare = "cloudflare"
ProviderClouddns = "clouddns"
ProviderDuckdns = "duckdns"
ProviderOVH = "ovh"
)
var providersGenMap = map[string]ProviderGenerator{
"": providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderLocal: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderCloudflare: providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
}
var Logger = logrus.WithField("module", "autocert")
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert")

View File

@@ -5,18 +5,17 @@ import (
"crypto/tls"
"crypto/x509"
"os"
"slices"
"sync"
"reflect"
"sort"
"time"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
"github.com/yusing/go-proxy/utils"
U "github.com/yusing/go-proxy/utils"
)
type Provider struct {
@@ -27,15 +26,14 @@ type Provider struct {
tlsCert *tls.Certificate
certExpiries CertExpiries
mutex sync.Mutex
}
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, error)
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type CertExpiries map[string]time.Time
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
return nil, E.Failure("get certificate")
return nil, ErrGetCertFailure
}
return p.tlsCert, nil
}
@@ -56,60 +54,81 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() E.NestedError {
ne := E.Failure("obtain certificate")
func (p *Provider) ObtainCert() (res E.NestedError) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
if p.cfg.Provider == ProviderLocal {
b.Addf("provider is set to %q", ProviderLocal)
return
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
b.Add(E.FailWith("init autocert client", err))
return
}
}
if p.user.Registration == nil {
if err := p.loadRegistration(); err.HasError() {
if err := p.registerACME(); err.HasError() {
b.Add(E.FailWith("register ACME", err))
return
}
}
}
client := p.client
if p.user.Registration == nil {
reg, err := E.Check(client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.IsNotNil() {
return ne.With(E.Failure("register account").With(err))
}
p.user.Registration = reg
}
req := certificate.ObtainRequest{
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.IsNotNil() {
return ne.With(err)
if err.HasError() {
b.Add(err)
return
}
err = p.saveCert(cert)
if err.IsNotNil() {
return ne.With(E.Failure("save certificate").With(err))
if err.HasError() {
b.Add(E.FailWith("save certificate", err))
return
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.IsNotNil() {
return ne.With(E.Failure("parse obtained certificate").With(err))
if err.HasError() {
b.Add(E.FailWith("parse obtained certificate", err))
return
}
expiries, err := getCertExpiries(&tlsCert)
if err.IsNotNil() {
return ne.With(E.Failure("get certificate expiry").With(err))
if err.HasError() {
b.Add(E.FailWith("get certificate expiry", err))
return
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
return E.Nil()
return nil
}
func (p *Provider) LoadCert() E.NestedError {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.IsNotNil() {
if err.HasError() {
return err
}
expiries, err := getCertExpiries(&cert)
if err.IsNotNil() {
if err.HasError() {
return err
}
p.tlsCert = &cert
p.certExpiries = expiries
p.renewIfNeeded()
return E.Nil()
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0)
return expiry.AddDate(0, -1, 0) // 1 month before
}
// this line should never be reached
panic("no certificate available")
@@ -120,139 +139,161 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) {
return
}
logger.Debug("starting renewal scheduler")
logger.Debug("started renewal scheduler")
defer logger.Debug("renewal scheduler stopped")
stop := make(chan struct{})
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
t := time.Until(p.ShouldRenewOn())
Logger.Infof("next renewal in %v", t.Round(time.Second))
go func() {
<-time.After(t)
close(stop)
}()
select {
case <-ctx.Done():
return
case <-stop:
if err := p.renewIfNeeded(); err.IsNotNil() {
Logger.Fatal(err)
}
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
}
}
}
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0600) // -rw-------
if err != nil {
return E.Failure("write key file").With(err)
func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.FailWith("create lego client", err)
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0644) // -rw-r--r--
if err != nil {
return E.Failure("write cert file").With(err)
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.FailWith("create lego provider", err)
}
return E.Nil()
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.FailWith("set challenge provider", err)
}
p.client = legoClient
return nil
}
func (p *Provider) needRenewal() bool {
expired := time.Now().After(p.ShouldRenewOn())
if expired {
return true
func (p *Provider) registerACME() E.NestedError {
if p.user.Registration != nil {
return nil
}
if len(p.cfg.Domains) != len(p.certExpiries) {
return true
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
return err
}
wantedDomains := make([]string, len(p.cfg.Domains))
p.user.Registration = reg
if err := p.saveRegistration(); err.HasError() {
logger.Warn(err)
}
return nil
}
func (p *Provider) loadRegistration() E.NestedError {
if p.user.Registration != nil {
return nil
}
reg := &registration.Resource{}
err := U.LoadJson(RegistrationFile, reg)
if err.HasError() {
return E.FailWith("parse registration file", err)
}
p.user.Registration = reg
return nil
}
func (p *Provider) saveRegistration() E.NestedError {
return U.SaveJson(RegistrationFile, p.user.Registration, 0o600)
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.FailWith("write key file", err)
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.FailWith("write cert file", err)
}
return nil
}
func (p *Provider) certState() CertState {
if time.Now().After(p.ShouldRenewOn()) {
return CertStateExpired
}
certDomains := make([]string, len(p.certExpiries))
copy(wantedDomains, p.cfg.Domains)
wantedDomains := make([]string, len(p.cfg.Domains))
i := 0
for domain := range p.certExpiries {
certDomains[i] = domain
i++
}
slices.Sort(wantedDomains)
slices.Sort(certDomains)
for i, domain := range certDomains {
if domain != wantedDomains[i] {
return true
}
copy(wantedDomains, p.cfg.Domains)
sort.Strings(wantedDomains)
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}
return false
return CertStateValid
}
func (p *Provider) renewIfNeeded() E.NestedError {
if !p.needRenewal() {
return E.Nil()
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
default:
return nil
}
p.mutex.Lock()
defer p.mutex.Unlock()
if !p.needRenewal() {
return E.Nil()
}
trials := 0
for {
err := p.ObtainCert()
if err.IsNotNil() {
return E.Nil()
}
trials++
if trials > 3 {
return E.Failure("renew certificate").With(err)
}
time.Sleep(5 * time.Second)
if err := p.ObtainCert(); err.HasError() {
return E.FailWith("renew certificate", err)
}
return nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.IsNotNil() {
return nil, E.Failure("parse certificate").With(err)
if err.HasError() {
return nil, E.FailWith("parse certificate", err)
}
if x509Cert.IsCA {
continue
}
r[x509Cert.Subject.CommonName] = x509Cert.NotAfter
}
return r, E.Nil()
}
func setOptions[T interface{}](cfg *T, opt M.AutocertProviderOpt) E.NestedError {
for k, v := range opt {
err := utils.SetFieldFromSnake(cfg, k, v)
if err.IsNotNil() {
return E.Failure("set autocert option").Subject(k).With(err)
for i := range x509Cert.DNSNames {
r[x509Cert.DNSNames[i]] = x509Cert.NotAfter
}
}
return E.Nil()
return r, nil
}
func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt M.AutocertProviderOpt) (challenge.Provider, error) {
return func(opt M.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
cfg := defaultCfg()
err := setOptions(cfg, opt)
if err.IsNotNil() {
err := U.Deserialize(opt, cfg)
if err.HasError() {
return nil, err
}
p, err := E.Check(newProvider(cfg))
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
return p, nil
}
}
var logger = logrus.WithField("module", "autocert")

View File

@@ -0,0 +1,50 @@
package provider_test
import (
"testing"
"github.com/go-acme/lego/v4/providers/dns/ovh"
U "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
"gopkg.in/yaml.v3"
)
// type Config struct {
// APIEndpoint string
// ApplicationKey string
// ApplicationSecret string
// ConsumerKey string
// OAuth2Config *OAuth2Config
// PropagationTimeout time.Duration
// PollingInterval time.Duration
// TTL int
// HTTPClient *http.Client
// }
func TestOVH(t *testing.T) {
cfg := &ovh.Config{}
testYaml := `
api_endpoint: https://eu.api.ovh.com
application_key: <application_key>
application_secret: <application_secret>
consumer_key: <consumer_key>
oauth2_config:
client_id: <client_id>
client_secret: <client_secret>
`
cfgExpected := &ovh.Config{
APIEndpoint: "https://eu.api.ovh.com",
ApplicationKey: "<application_key>",
ApplicationSecret: "<application_secret>",
ConsumerKey: "<consumer_key>",
OAuth2Config: &ovh.OAuth2Config{ClientID: "<client_id>", ClientSecret: "<client_secret>"},
}
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectTrue(t, U.Deserialize(opt, cfg).NoError())
ExpectDeepEqual(t, cfg, cfgExpected)
}

9
src/autocert/state.go Normal file
View File

@@ -0,0 +1,9 @@
package autocert
type CertState int
const (
CertStateValid CertState = 0
CertStateExpired CertState = iota
CertStateMismatch CertState = iota
)

View File

@@ -12,27 +12,37 @@ type Args struct {
}
const (
CommandStart = ""
CommandValidate = "validate"
CommandReload = "reload"
CommandStart = ""
CommandValidate = "validate"
CommandListConfigs = "ls-config"
CommandListRoutes = "ls-routes"
CommandReload = "reload"
CommandDebugListEntries = "debug-ls-entries"
)
var ValidCommands = []string{CommandStart, CommandValidate, CommandReload}
var ValidCommands = []string{
CommandStart,
CommandValidate,
CommandListConfigs,
CommandListRoutes,
CommandReload,
CommandDebugListEntries,
}
func GetArgs() Args {
var args Args
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArgs(args.Command, ValidCommands); err.IsNotNil() {
if err := validateArg(args.Command); err.HasError() {
logrus.Fatal(err)
}
return args
}
func validateArgs[T comparable](arg T, validArgs []T) E.NestedError {
for _, v := range validArgs {
func validateArg(arg string) E.NestedError {
for _, v := range ValidCommands {
if arg == v {
return E.Nil()
return nil
}
}
return E.Invalid("argument", arg)

View File

@@ -37,13 +37,6 @@ const (
const DockerHostFromEnv = "$DOCKER_HOST"
const (
ProxyHTTPPort = ":80"
ProxyHTTPSPort = ":443"
APIHTTPPort = ":8888"
PanelHTTPPort = ":8080"
)
var WellKnownHTTPPorts = map[uint16]bool{
80: true,
8000: true,
@@ -53,17 +46,17 @@ var WellKnownHTTPPorts = map[uint16]bool{
}
var (
ImageNamePortMapTCP = map[string]int{
"postgres": 5432,
"mysql": 3306,
"mariadb": 3306,
"redis": 6379,
"mssql": 1433,
"memcached": 11211,
"rabbitmq": 5672,
"mongo": 27017,
}
ExtraNamePortMapTCP = map[string]int{
ServiceNamePortMapTCP = map[string]int{
"postgres": 5432,
"mysql": 3306,
"mariadb": 3306,
"redis": 6379,
"mssql": 1433,
"memcached": 11211,
"rabbitmq": 5672,
"mongo": 27017,
"minecraft-server": 25565,
"dns": 53,
"ssh": 22,
"ftp": 21,
@@ -71,20 +64,9 @@ var (
"pop3": 110,
"imap": 143,
}
NamePortMapTCP = func() map[string]int {
m := make(map[string]int)
for k, v := range ImageNamePortMapTCP {
m[k] = v
}
for k, v := range ExtraNamePortMapTCP {
m[k] = v
}
return m
}()
)
// docker library uses uint16, so followed here
var ImageNamePortMapHTTP = map[string]uint16{
var ImageNamePortMapHTTP = map[string]int{
"nginx": 80,
"httpd": 80,
"adguardhome": 3000,
@@ -101,3 +83,10 @@ var ImageNamePortMapHTTP = map[string]uint16{
"dockge": 5001,
"nginx-proxy-manager": 81,
}
const (
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "10s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)

View File

@@ -2,22 +2,26 @@ package common
import (
"os"
"strings"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/utils"
)
var NoSchemaValidation = getEnvBool("GOPROXY_NO_SCHEMA_VALIDATION")
var IsDebug = getEnvBool("GOPROXY_DEBUG")
var LogLevel = func() logrus.Level {
if IsDebug {
logrus.SetLevel(logrus.DebugLevel)
}
return logrus.GetLevel()
}()
var (
NoSchemaValidation = getEnvBool("GOPROXY_NO_SCHEMA_VALIDATION")
IsDebug = getEnvBool("GOPROXY_DEBUG")
ProxyHTTPPort = ":" + getEnv("GOPROXY_HTTP_PORT", "80")
ProxyHTTPSPort = ":" + getEnv("GOPROXY_HTTPS_PORT", "443")
APIHTTPPort = ":" + getEnv("GOPROXY_API_PORT", "8888")
)
func getEnvBool(key string) bool {
v := os.Getenv(key)
return v == "1" || strings.ToLower(v) == "true" || strings.ToLower(v) == "yes" || strings.ToLower(v) == "on"
return U.ParseBool(os.Getenv(key))
}
func getEnv(key string, defaultValue string) string {
value, ok := os.LookupEnv(key)
if !ok {
value = defaultValue
}
return value
}

View File

@@ -2,6 +2,7 @@ package config
import (
"context"
"os"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/autocert"
@@ -17,32 +18,26 @@ import (
)
type Config struct {
value *M.Config
l logrus.FieldLogger
reader U.Reader
proxyProviders *F.Map[string, *PR.Provider]
value *M.Config
proxyProviders F.Map[string, *PR.Provider]
autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
reloadReq chan struct{}
}
func New() (*Config, E.NestedError) {
func Load() (*Config, E.NestedError) {
cfg := &Config{
l: logrus.WithField("module", "config"),
reader: U.NewFileReader(common.ConfigPath),
watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
if err := cfg.load(); err.IsNotNil() {
return nil, err
}
cfg.startProviders()
cfg.watchChanges()
return cfg, E.Nil()
return cfg, cfg.load()
}
func Validate(data []byte) E.NestedError {
@@ -58,97 +53,27 @@ func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
}
func (cfg *Config) Dispose() {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
if cfg.watcherCancel != nil {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
}
cfg.stopProviders()
cfg.l.Debug("stopped providers")
}
func (cfg *Config) Reload() E.NestedError {
cfg.stopProviders()
if err := cfg.load(); err.IsNotNil() {
if err := cfg.load(); err.HasError() {
return err
}
cfg.startProviders()
return E.Nil()
cfg.StartProxyProviders()
return nil
}
func (cfg *Config) FindRoute(alias string) R.Route {
r := cfg.proxyProviders.Find(
func(p *PR.Provider) (any, bool) {
rs := p.GetCurrentRoutes()
if rs.Contains(alias) {
return rs.Get(alias), true
}
return nil, false
},
)
if r == nil {
return nil
}
return r.(R.Route)
func (cfg *Config) StartProxyProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.proxyProviders.Each(func(p *PR.Provider) {
prName := p.GetName()
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) {
obj, err := U.Serialize(r)
if err != nil {
cfg.l.Error(err)
return
}
obj["provider"] = prName
switch r.(type) {
case *R.StreamRoute:
obj["type"] = "stream"
case *R.HTTPRoute:
obj["type"] = "reverse_proxy"
default:
panic("bug: should not reach here")
}
routes[a] = obj
})
})
return routes
}
func (cfg *Config) Statistics() map[string]interface{} {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]interface{})
cfg.proxyProviders.Each(func(p *PR.Provider) {
stats := make(map[string]interface{})
nStreams := 0
nRPs := 0
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) {
switch r.(type) {
case *R.StreamRoute:
nStreams++
nTotalStreams++
case *R.HTTPRoute:
nRPs++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
stats["type"] = p.GetType()
stats["num_streams"] = nStreams
stats["num_reverse_proxies"] = nRPs
providerStats[p.GetName()] = stats
})
return map[string]interface{}{
"num_total_streams": nTotalStreams,
"num_total_reverse_proxies": nTotalRPs,
"providers": providerStats,
}
}
func (cfg *Config) watchChanges() {
func (cfg *Config) WatchChanges() {
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
go func() {
for {
@@ -156,7 +81,7 @@ func (cfg *Config) watchChanges() {
case <-cfg.watcherCtx.Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err.IsNotNil() {
if err := cfg.Reload(); err.HasError() {
cfg.l.Error(err)
}
}
@@ -182,76 +107,171 @@ func (cfg *Config) watchChanges() {
}()
}
func (cfg *Config) load() E.NestedError {
cfg.l.Debug("loading config")
func (cfg *Config) FindRoute(alias string) R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}
return nil, false
},
)
}
data, err := cfg.reader.Read()
if err.IsNotNil() {
return E.Failure("read config").With(err)
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
routes[alias] = obj
})
return routes
}
func (cfg *Config) Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]any)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
s, ok := providerStats[p.GetName()]
if !ok {
s = make(map[string]int)
}
stats := s.(map[string]int)
switch r.Type() {
case R.RouteTypeStream:
stats["num_streams"]++
nTotalStreams++
case R.RouteTypeReverseProxy:
stats["num_reverse_proxies"]++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
return map[string]any{
"num_total_streams": nTotalStreams,
"num_total_reverse_proxies": nTotalRPs,
"providers": providerStats,
}
}
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.IsNotNil() {
return E.Failure("parse config").With(err)
func (cfg *Config) DumpEntries() map[string]*M.ProxyEntry {
entries := make(map[string]*M.ProxyEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
})
return entries
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
do(a, r, p)
})
})
}
func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
b.Add(E.FailWith("read config", err))
return
}
if !common.NoSchemaValidation {
if err = Validate(data); err.IsNotNil() {
return err
if err = Validate(data); err.HasError() {
b.Add(E.FailWith("schema validation", err))
return
}
}
warnings := E.NewBuilder("errors loading config")
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
return
}
cfg.l.Debug("starting autocert")
ap, err := autocert.NewConfig(&model.AutoCert).GetProvider()
if err.IsNotNil() {
warnings.Add(E.Failure("autocert provider").With(err))
} else {
cfg.l.Debug("started autocert")
}
cfg.autocertProvider = ap
cfg.l.Debug("loading providers")
cfg.proxyProviders = F.NewMap[string, *PR.Provider]()
for _, filename := range model.Providers.Files {
p := PR.NewFileProvider(filename)
cfg.proxyProviders.Set(p.GetName(), p)
}
for name, dockerHost := range model.Providers.Docker {
p := PR.NewDockerProvider(name, dockerHost)
cfg.proxyProviders.Set(p.GetName(), p)
}
cfg.l.Debug("loaded providers")
// errors are non fatal below
b.WithSeverity(E.SeverityWarning)
b.Add(cfg.initAutoCert(&model.AutoCert))
b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model
return
}
if err := warnings.Build(); err.IsNotNil() {
cfg.l.Warn(err)
func (cfg *Config) initAutoCert(autocertCfg *M.AutoCertConfig) (err E.NestedError) {
if cfg.autocertProvider != nil {
return
}
cfg.l.Debug("loaded config")
return E.Nil()
cfg.l.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
for _, filename := range providers.Files {
p, err := PR.NewFileProvider(filename)
if err != nil {
b.Add(err.Subject(filename))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(filename))
}
for name, dockerHost := range providers.Docker {
p, err := PR.NewDockerProvider(name, dockerHost)
if err != nil {
b.Add(err.Subject(dockerHost))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(dockerHost))
}
return
}
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("cannot %s these providers", action)
cfg.proxyProviders.EachKVParallel(func(name string, p *PR.Provider) {
if err := do(p); err.IsNotNil() {
errors.Add(E.From(err).Subject(p))
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
errors.Add(err.Subject(p))
}
})
if err := errors.Build(); err.IsNotNil() {
if err := errors.Build(); err.HasError() {
cfg.l.Error(err)
}
}
func (cfg *Config) startProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
}
func (cfg *Config) stopProviders() {
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
}

View File

@@ -3,6 +3,7 @@ package docker
import (
"net/http"
"sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
@@ -11,14 +12,64 @@ import (
E "github.com/yusing/go-proxy/error"
)
type Client = *client.Client
type Client struct {
key string
refCount *atomic.Int32
*client.Client
l logrus.FieldLogger
}
func ParseDockerHostname(host string) (string, E.NestedError) {
if host == common.DockerHostFromEnv {
return host, nil
} else if host == "" {
return "localhost", nil
}
url, err := E.Check(client.ParseHostURL(host))
if err != nil {
return "", E.Invalid("host", host).With(err)
}
return url.Hostname(), nil
}
func (c Client) DaemonHostname() string {
// DaemonHost should always return a valid host
hostname, _ := ParseDockerHostname(c.DaemonHost())
return hostname
}
func (c Client) Connected() bool {
return c.Client != nil
}
// if the client is still referenced, this is no-op
func (c *Client) Close() error {
if c.refCount.Add(-1) > 0 {
return nil
}
clientMapMu.Lock()
defer clientMapMu.Unlock()
delete(clientMap, c.key)
client := c.Client
c.Client = nil
c.l.Debugf("client closed")
if client != nil {
return client.Close()
}
return nil
}
// ConnectClient creates a new Docker client connection to the specified host.
//
// Returns existing client if available.
//
// Parameters:
// - host: the host to connect to (either a URL or "FROM_ENV").
// - host: the host to connect to (either a URL or common.DockerHostFromEnv).
//
// Returns:
// - Client: the Docker client connection.
@@ -29,7 +80,8 @@ func ConnectClient(host string) (Client, E.NestedError) {
// check if client exists
if client, ok := clientMap[host]; ok {
return client, E.Nil()
client.refCount.Add(1)
return client, nil
}
// create client
@@ -40,8 +92,8 @@ func ConnectClient(host string) (Client, E.NestedError) {
opt = clientOptEnvHost
default:
helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.IsNotNil() {
logger.Fatalf("unexpected error: %s", err)
if err.HasError() {
return Client{}, E.UnexpectedError(err.Error())
}
if helper != nil {
httpClient := &http.Client{
@@ -65,12 +117,21 @@ func ConnectClient(host string) (Client, E.NestedError) {
client, err := E.Check(client.NewClientWithOpts(opt...))
if err.IsNotNil() {
return nil, err
if err.HasError() {
return Client{}, err
}
clientMap[host] = client
return client, E.Nil()
c := Client{
Client: client,
key: host,
refCount: &atomic.Int32{},
l: logger.WithField("docker_client", client.DaemonHost()),
}
c.refCount.Add(1)
c.l.Debugf("client connected")
clientMap[host] = c
return clientMap[host], nil
}
func CloseAllClients() {
@@ -83,12 +144,13 @@ func CloseAllClients() {
logger.Debug("closed all clients")
}
var clientMap map[string]Client = make(map[string]Client)
var clientMapMu sync.Mutex
var (
clientMap map[string]Client = make(map[string]Client)
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
var clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
var logger = logrus.WithField("module", "docker")
logger = logrus.WithField("module", "docker")
)

View File

@@ -12,35 +12,41 @@ import (
)
type ClientInfo struct {
Host string
Client Client
Containers []types.Container
}
func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) {
var listOptions = container.ListOptions{
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// ),
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
dockerClient, err := ConnectClient(clientHost)
if err.IsNotNil() {
return nil, E.Failure("create docker client").With(err)
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
containers, err := E.Check(dockerClient.ContainerList(ctx, container.ListOptions{}))
if err.IsNotNil() {
return nil, E.Failure("list containers").With(err)
var containers []types.Container
if getContainer {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
}
// extract host from docker client url
// since the services being proxied to
// should have the same IP as the docker client
url, err := E.Check(client.ParseHostURL(dockerClient.DaemonHost()))
if err.IsNotNil() {
return nil, E.Invalid("host url", dockerClient.DaemonHost()).With(err)
}
if url.Scheme == "unix" {
return &ClientInfo{Host: "localhost", Containers: containers}, E.Nil()
}
return &ClientInfo{Host: url.Hostname(), Containers: containers}, E.Nil()
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
}
func IsErrConnectionFailed(err error) bool {

111
src/docker/container.go Normal file
View File

@@ -0,0 +1,111 @@
package docker
import (
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/utils"
)
type ProxyProperties struct {
DockerHost string `yaml:"-" json:"docker_host"`
ContainerName string `yaml:"-" json:"container_name"`
ImageName string `yaml:"-" json:"image_name"`
Aliases []string `yaml:"-" json:"aliases"`
IsExcluded bool `yaml:"-" json:"is_excluded"`
FirstPort string `yaml:"-" json:"first_port"`
IdleTimeout string `yaml:"-" json:"idle_timeout"`
WakeTimeout string `yaml:"-" json:"wake_timeout"`
StopMethod string `yaml:"-" json:"stop_method"`
StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only
Running bool `yaml:"-" json:"running"`
}
type Container struct {
*types.Container
*ProxyProperties
}
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ImageName: res.getImageName(),
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)),
FirstPort: res.firstPortOrEmpty(),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
StopSignal: res.getDeleteLabel(LabelStopSignal),
Running: c.Status == "running" || c.State == "running",
}
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
for _, v := range bindings {
pubPort, _ := strconv.Atoi(v.HostPort)
privPort, _ := strconv.Atoi(k.Port())
ports = append(ports, types.Port{
IP: v.HostIP,
PublicPort: uint16(pubPort),
PrivatePort: uint16(privPort),
})
}
}
return FromDocker(&types.Container{
ID: json.ID,
Names: []string{json.Name},
Image: json.Image,
Ports: ports,
Labels: json.Config.Labels,
State: json.State.Status,
Status: json.State.Status,
}, dockerHost)
}
func (c Container) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
}
return ""
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LableAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
}
}
func (c Container) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c Container) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[0], "/")
return slashSep[len(slashSep)-1]
}
func (c Container) firstPortOrEmpty() string {
if len(c.Ports) == 0 {
return ""
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort)
}
}
return ""
}

View File

@@ -9,7 +9,7 @@ type (
Icon string
Category string
Description string
WidgetConfig map[string]interface{}
WidgetConfig map[string]any
}
)

View File

@@ -0,0 +1,14 @@
package idlewatcher
import "net/http"
type (
roundTripper struct {
patched roundTripFunc
}
roundTripFunc func(*http.Request) (*http.Response, error)
)
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.patched(req)
}

View File

@@ -0,0 +1,357 @@
package idlewatcher
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
W "github.com/yusing/go-proxy/watcher"
event "github.com/yusing/go-proxy/watcher/events"
)
type watcher struct {
*P.ReverseProxyEntry
client D.Client
refCount atomic.Int32
stopByMethod StopCallback
wakeCh chan struct{}
wakeDone chan E.NestedError
running atomic.Bool
ctx context.Context
cancel context.CancelFunc
l logrus.FieldLogger
}
type (
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() E.NestedError
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[entry.ContainerName]; ok {
w.refCount.Add(1)
w.ReverseProxyEntry = entry
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
if err.HasError() {
return nil, failure.With(err)
}
w := &watcher{
ReverseProxyEntry: entry,
client: client,
wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError, 1),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.running.Store(entry.ContainerRunning)
w.stopByMethod = w.getStopCallback()
watcherMap[w.ContainerName] = w
go func() {
newWatcherCh <- w
}()
return w, nil
}
// If the container is not registered, this is no-op
func Unregister(containerName string) {
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[containerName]; ok {
if w.refCount.Add(-1) > 0 {
return
}
if w.cancel != nil {
w.cancel()
}
w.client.Close()
delete(watcherMap, containerName)
}
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
defer mainLoopWg.Wait()
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watch()
Unregister(w.ContainerName)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
}
func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
return roundTripper{patched: func(r *http.Request) (*http.Response, error) {
return w.roundTrip(rtp.RoundTrip, r)
}}
}
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
w.wakeCh <- struct{}{}
if w.running.Load() {
return origRoundTrip(req)
}
timeout := time.After(w.WakeTimeout)
for {
if w.running.Load() {
return origRoundTrip(req)
}
select {
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-w.wakeDone:
if err != nil {
return nil, err.Error()
}
case <-timeout:
return getLoadingResponse(), nil
}
}
}
func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
}
func (w *watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerName)
}
func (w *watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerName, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerName)
}
func (w *watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerName, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerName)
if err != nil {
return "", E.FailWith("inspect container", err)
}
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch status {
case "exited", "dead":
return E.From(w.containerStart())
case "paused":
return E.From(w.containerUnpause())
case "running":
w.running.Store(true)
return nil
default:
return E.Unexpected("container state", status)
}
}
func (w *watcher) getStopCallback() StopCallback {
var cb func() error
switch w.StopMethod {
case PT.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
if status != "running" {
return nil
}
return E.From(cb())
}
}
func (w *watcher) watch() {
watcherCtx, watcherCancel := context.WithCancel(context.Background())
w.ctx = watcherCtx
w.cancel = watcherCancel
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
defer close(w.wakeCh)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainerName(w.ContainerName),
W.DockerFilterStart,
W.DockerFilterStop,
W.DockerFilterDie,
W.DockerFilterKill,
W.DockerFilterPause,
W.DockerFilterUnpause,
),
})
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
for {
select {
case <-mainLoopCtx.Done():
w.cancel()
case <-watcherCtx.Done():
w.l.Debug("stopped")
return
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))
}
case e := <-dockerEventCh:
switch e.Action {
case event.ActionDockerStartUnpause:
w.running.Store(true)
w.l.Infof("%s %s", e.ActorName, e.Action)
case event.ActionDockerStopPause:
w.running.Store(false)
w.l.Infof("%s %s", e.ActorName, e.Action)
}
case <-ticker.C:
w.l.Debug("timeout")
ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
}
case <-w.wakeCh:
w.l.Debug("wake signal received")
ticker.Reset(w.IdleTimeout)
err := w.wakeIfStopped()
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("wake", err))
}
select {
case w.wakeDone <- err: // this is passed to roundtrip
default:
}
}
}
}
func getLoadingResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
"Content-Type": {"text/html"},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(bytes.NewReader((loadingPage))),
ContentLength: int64(len(loadingPage)),
}
}
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
loadingPage = []byte(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Loading...</title>
</head>
<body>
<script>
window.onload = function() {
setTimeout(function() {
window.location.reload()
}, 1000)
// fetch(window.location.href)
// .then(resp => resp.text())
// .then(data => { document.body.innerHTML = data; })
// .catch(err => { document.body.innerHTML = 'Error: ' + err; });
};
</script>
<h1>Container is starting... Please wait</h1>
</body>
</html>
`[1:])
)

19
src/docker/inspect.go Normal file
View File

@@ -0,0 +1,19 @@
package docker
import (
"context"
"time"
E "github.com/yusing/go-proxy/error"
)
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return Container{}, E.From(err)
}
return FromJson(json, c.key), nil
}

View File

@@ -23,7 +23,7 @@ type Label struct {
// Returns:
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
return U.SetFieldFromSnake(obj, l.Attribute, l.Value)
return U.Deserialize(map[string]any{l.Attribute: l.Value}, obj)
}
type ValueParser func(string) (any, E.NestedError)
@@ -36,7 +36,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
return &Label{
Namespace: label,
Value: value,
}, E.Nil()
}, nil
}
l := &Label{
@@ -54,20 +54,20 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
// find if namespace has value parser
pm, ok := labelValueParserMap[l.Namespace]
if !ok {
return l, E.Nil()
return l, nil
}
// find if attribute has value parser
p, ok := pm[l.Attribute]
if !ok {
return l, E.Nil()
return l, nil
}
// try to parse value
v, err := p(value)
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
l.Value = v
return l, E.Nil()
return l, nil
}
func RegisterNamespace(namespace string, pm ValueParserMap) {

View File

@@ -10,7 +10,7 @@ import (
func yamlListParser(value string) (any, E.NestedError) {
value = strings.TrimSpace(value)
if value == "" {
return []string{}, E.Nil()
return []string{}, nil
}
var data []string
err := E.From(yaml.Unmarshal([]byte(value), &data))
@@ -34,23 +34,15 @@ func yamlStringMappingParser(value string) (any, E.NestedError) {
h[key] = val
}
}
return h, E.Nil()
}
func commaSepParser(value string) (any, E.NestedError) {
v := strings.Split(value, ",")
for i := range v {
v[i] = strings.TrimSpace(v[i])
}
return v, E.Nil()
return h, nil
}
func boolParser(value string) (any, E.NestedError) {
switch strings.ToLower(value) {
case "true", "yes", "1":
return true, E.Nil()
return true, nil
case "false", "no", "0":
return false, E.Nil()
return false, nil
default:
return nil, E.Invalid("boolean value", value)
}
@@ -60,7 +52,6 @@ const NSProxy = "proxy"
var _ = func() int {
RegisterNamespace(NSProxy, ValueParserMap{
"aliases": commaSepParser,
"path_patterns": yamlListParser,
"set_headers": yamlStringMappingParser,
"hide_headers": yamlListParser,

View File

@@ -7,6 +7,7 @@ import (
"testing"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils/testing"
)
func makeLabel(namespace string, alias string, field string) string {
@@ -18,29 +19,23 @@ func TestHomePageLabel(t *testing.T) {
field := "ip"
v := "bar"
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
ExpectNoError(t, err.Error())
if pl.Target != alias {
t.Errorf("expected alias=%s, got %s", alias, pl.Target)
t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
}
if pl.Attribute != field {
t.Errorf("expected field=%s, got %s", field, pl.Target)
t.Errorf("Expected field=%s, got %s", field, pl.Target)
}
if pl.Value != v {
t.Errorf("expected value=%q, got %s", v, pl.Value)
t.Errorf("Expected value=%q, got %s", v, pl.Value)
}
}
func TestStringProxyLabel(t *testing.T) {
v := "bar"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
if pl.Value != v {
t.Errorf("expected value=%q, got %s", v, pl.Value)
}
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(string), v)
}
func TestBoolProxyLabelValid(t *testing.T) {
@@ -57,12 +52,8 @@ func TestBoolProxyLabelValid(t *testing.T) {
for k, v := range tests {
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
if pl.Value != v {
t.Errorf("expected value=%v, got %v", v, pl.Value)
}
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(bool), v)
}
}
@@ -71,7 +62,7 @@ func TestBoolProxyLabelInvalid(t *testing.T) {
field := "no_tls_verify"
_, err := ParseLabel(makeLabel(NSProxy, alias, field), "invalid")
if !err.Is(E.ErrInvalid) {
t.Errorf("expected err InvalidProxyLabel, got %s", err.Error())
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
}
}
@@ -87,17 +78,12 @@ X-Custom-Header2: boo`
}
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
hGot, ok := pl.Value.(map[string]string)
if !ok {
t.Errorf("value is not a map[string]string, but %T", pl.Value)
return
}
if !reflect.DeepEqual(h, hGot) {
t.Errorf("expected %v, got %v", h, hGot)
ExpectNoError(t, err.Error())
hGot := ExpectType[map[string]string](t, pl.Value)
if hGot != nil && !reflect.DeepEqual(h, hGot) {
t.Errorf("Expected %v, got %v", h, hGot)
}
}
func TestSetHeaderProxyLabelInvalid(t *testing.T) {
@@ -110,7 +96,7 @@ func TestSetHeaderProxyLabelInvalid(t *testing.T) {
for _, v := range tests {
_, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
if !err.Is(E.ErrInvalid) {
t.Errorf("expected invalid err for %q, got %s", v, err.Error())
t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
}
}
}
@@ -123,47 +109,32 @@ func TestHideHeadersProxyLabel(t *testing.T) {
`
v = strings.TrimPrefix(v, "\n")
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
sGot, ok := pl.Value.([]string)
ExpectNoError(t, err.Error())
sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if !ok {
t.Errorf("value is not []string, but %T", pl.Value)
}
if !reflect.DeepEqual(sGot, sWant) {
t.Errorf("expected %q, got %q", sWant, sGot)
if sGot != nil {
ExpectDeepEqual(t, sGot, sWant)
}
}
func TestCommaSepProxyLabelSingle(t *testing.T) {
v := "a"
pl, err := ParseLabel("proxy.aliases", v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
sGot, ok := pl.Value.([]string)
sWant := []string{"a"}
if !ok {
t.Errorf("value is not []string, but %T", pl.Value)
}
if !reflect.DeepEqual(sGot, sWant) {
t.Errorf("expected %q, got %q", sWant, sGot)
}
}
// func TestCommaSepProxyLabelSingle(t *testing.T) {
// v := "a"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"a"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// }
// }
func TestCommaSepProxyLabelMulti(t *testing.T) {
v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
pl, err := ParseLabel("proxy.aliases", v)
if err.IsNotNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
sGot, ok := pl.Value.([]string)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if !ok {
t.Errorf("value is not []string, but %T", pl.Value)
}
if !reflect.DeepEqual(sGot, sWant) {
t.Errorf("expected %q, got %q", sWant, sGot)
}
}
// func TestCommaSepProxyLabelMulti(t *testing.T) {
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// }
// }

13
src/docker/labels.go Normal file
View File

@@ -0,0 +1,13 @@
package docker
const (
WildcardAlias = "*"
LableAliases = NSProxy + ".aliases"
LableExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method"
LabelStopTimeout = NSProxy + ".stop_timeout"
LabelStopSignal = NSProxy + ".stop_signal"
)

View File

@@ -6,16 +6,23 @@ import (
)
type Builder struct {
message string
errors []error
*builder
}
type builder struct {
message string
errors []NestedError
severity Severity
sync.Mutex
}
func NewBuilder(format string, args ...any) *Builder {
return &Builder{message: fmt.Sprintf(format, args...)}
func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
}
func (b *Builder) Add(err error) *Builder {
// adding nil / nil is no-op,
// you may safely pass expressions returning error to it
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
b.errors = append(b.errors, err)
@@ -24,8 +31,17 @@ func (b *Builder) Add(err error) *Builder {
return b
}
func (b *Builder) Addf(format string, args ...any) *Builder {
return b.Add(fmt.Errorf(format, args...))
func (b Builder) AddE(err error) Builder {
return b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
func (b Builder) WithSeverity(s Severity) Builder {
b.severity = s
return b
}
// Build builds a NestedError based on the errors collected in the Builder.
@@ -35,9 +51,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder {
//
// Returns:
// - NestedError: the built NestedError.
func (b *Builder) Build() NestedError {
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return Nil()
return nil
}
return Join(b.message, b.errors...)
return Join(b.message, b.errors...).Severity(b.severity)
}
func (b Builder) To(ptr *NestedError) {
if *ptr == nil {
*ptr = b.Build()
} else {
**ptr = *b.Build()
}
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
}

View File

@@ -1,27 +1,52 @@
package error
import "testing"
import (
"testing"
func TestBuilder(t *testing.T) {
. "github.com/yusing/go-proxy/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
for range 3 {
eb.Add(nil)
}
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().Error()
got := eb.Build().String()
expected1 :=
(`error occurred:
- Action 1 failed:
- invalid Inner - 1
- invalid Inner - 2
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner - 3`)
- invalid Inner: 3`)
expected2 :=
(`error occurred:
- Action 1 failed:
- invalid Inner - 2
- invalid Inner - 1
- invalid Inner: 2
- invalid Inner: 1
- Action 2 failed:
- invalid Inner - 3`)
- invalid Inner: 3`)
if got != expected1 && got != expected2 {
t.Errorf("expected \n%s, got \n%s", expected1, got)
}

View File

@@ -7,36 +7,26 @@ import (
)
type (
// NestedError is an error with an inner error
// and a list of extra nested errors.
//
// It is designed to be non nil.
//
// You can use it to join multiple errors,
// or to set a inner reason for a nested error.
//
// When a method returns both valid values and errors,
// You should return (Slice/Map, NestedError).
// Caller then should handle the nested error,
// and continue with the valid values.
NestedError struct {
subject string
err error // can be nil
extras []NestedError
NestedError = *nestedError
nestedError struct {
subject string
err error
extras []nestedError
severity Severity
}
Severity uint8
)
func Nil() NestedError { return NestedError{} }
const (
SeverityFatal Severity = iota
SeverityWarning
)
func From(err error) NestedError {
switch err := err.(type) {
case nil:
return Nil()
case NestedError:
return err
default:
return NestedError{err: err}
if IsNil(err) {
return nil
}
return &nestedError{err: err}
}
// Check is a helper function that
@@ -45,42 +35,86 @@ func Check[T any](obj T, err error) (T, NestedError) {
return obj, From(err)
}
func Join(message string, err ...error) NestedError {
extras := make([]NestedError, 0, len(err))
func Join(message string, err ...NestedError) NestedError {
extras := make([]nestedError, len(err))
nErr := 0
for _, e := range err {
if err == nil {
for i, e := range err {
if e == nil {
continue
}
extras = append(extras, From(e))
extras[i] = *e
nErr += 1
}
if nErr == 0 {
return Nil()
return nil
}
return NestedError{
return &nestedError{
err: errors.New(message),
extras: extras,
}
}
func (ne NestedError) Error() string {
func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne NestedError) Is(err error) bool {
return errors.Is(ne.err, err)
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne NestedError) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne NestedError) With(s any) NestedError {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case error:
case NestedError:
return ne.withError(ss)
case error:
return ne.withError(From(ss))
case string:
msg = ss
case fmt.Stringer:
@@ -88,14 +122,17 @@ func (ne NestedError) With(s any) NestedError {
default:
msg = fmt.Sprint(s)
}
return ne.withError(errors.New(msg))
return ne.withError(From(errors.New(msg)))
}
func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(fmt.Errorf(format, args...))
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any) NestedError {
if ne == nil {
return ne
}
switch ss := s.(type) {
case string:
ne.subject = ss
@@ -108,6 +145,9 @@ func (ne NestedError) Subject(s any) NestedError {
}
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q")
}
@@ -118,39 +158,63 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
return ne
}
func (ne NestedError) IsNil() bool {
return ne.err == nil
func (ne NestedError) Severity(s Severity) NestedError {
if ne == nil {
return ne
}
ne.severity = s
return ne
}
func (ne NestedError) IsNotNil() bool {
return ne.err != nil
func (ne NestedError) Warn() NestedError {
if ne == nil {
return ne
}
ne.severity = SeverityWarning
return ne
}
func (ne NestedError) NoError() bool {
return ne == nil
}
func (ne NestedError) HasError() bool {
return ne != nil
}
func (ne NestedError) IsFatal() bool {
return ne != nil && ne.severity == SeverityFatal
}
func (ne NestedError) IsWarning() bool {
return ne != nil && ne.severity == SeverityWarning
}
func errorf(format string, args ...any) NestedError {
return From(fmt.Errorf(format, args...))
}
func (ne NestedError) withError(err error) NestedError {
ne.extras = append(ne.extras, From(err))
func (ne NestedError) withError(err NestedError) NestedError {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
ne.writeIndents(sb, level)
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.IsNil() {
if ne.NoError() {
sb.WriteString("nil")
return
}
sb.WriteString(ne.err.Error())
if ne.subject != "" {
if ne.err != nil {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
} else {
sb.WriteString(fmt.Sprint(ne.subject))
}
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
@@ -161,8 +225,32 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string)
}
}
func (ne *NestedError) writeIndents(sb *strings.Builder, level int) {
func (ne NestedError) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return errors.New(sb.String())
}
res = fmt.Errorf("%s%w", sb.String(), ne.err)
sb.Reset()
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
res = fmt.Errorf("%w%s", res, sb.String())
for _, extra := range ne.extras {
res = errors.Join(res, extra.buildError(level+1, "- "))
}
}
return res
}

View File

@@ -1,46 +1,74 @@
package error
package error_test
import (
"errors"
"testing"
. "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils/testing"
)
func AssertEq[T comparable](t *testing.T, got, want T) {
t.Helper()
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
}
}
func TestErrorIs(t *testing.T) {
AssertEq(t, Failure("foo").Is(ErrFailure), true)
AssertEq(t, Failure("foo").With("bar").Is(ErrFailure), true)
AssertEq(t, Failure("foo").With("bar").Is(ErrInvalid), false)
AssertEq(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid), false)
ExpectTrue(t, Failure("foo").Is(ErrFailure))
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure))
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid))
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
AssertEq(t, Invalid("foo", "bar").Is(ErrInvalid), true)
AssertEq(t, Invalid("foo", "bar").Is(ErrFailure), false)
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
AssertEq(t, Nil().Is(nil), true)
AssertEq(t, Nil().Is(ErrInvalid), false)
AssertEq(t, Invalid("foo", "bar").Is(nil), false)
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
}
func TestNil(t *testing.T) {
AssertEq(t, Nil().IsNil(), true)
AssertEq(t, Nil().IsNotNil(), false)
AssertEq(t, Nil().Error(), "nil")
func TestErrorNestedIs(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrAlreadyExist))
err.With(AlreadyExist("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrAlreadyExist))
ExpectFalse(t, err.Is(ErrInvalid))
}
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
AssertEq(t, ne.Error(), "foo bar failed")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
AssertEq(t, ne.Error(), "foo bar failed for \"baz\"")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
}
func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz")
AssertEq(t, ne.Error(), "foo failed:\n - bar\n - baz")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
}
func TestErrorNested(t *testing.T) {
@@ -76,5 +104,6 @@ func TestErrorNested(t *testing.T) {
- inner3 failed for "action 3":
- 3
- 3`
AssertEq(t, ne.Error(), want)
ExpectEqual(t, ne.String(), want)
ExpectEqual(t, ne.Error().Error(), want)
}

View File

@@ -5,33 +5,48 @@ import (
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrNotExists = stderrors.New("does not exist")
ErrDuplicated = stderrors.New("duplicated")
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrAlreadyExist = stderrors.New("already exist")
)
const fmtSubjectWhat = "%w %v: %v"
func Failure(what string) NestedError {
return errorf("%s %w", what, ErrFailure)
}
func FailureWhy(what string, why string) NestedError {
func FailedWhy(what string, why string) NestedError {
return errorf("%s %w because %s", what, ErrFailure, why)
}
func FailWith(what string, err any) NestedError {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError {
return errorf("%w %v - %v", ErrInvalid, subject, what)
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) NestedError {
return errorf("%w %v - %v", ErrUnsupported, subject, what)
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func NotExists(subject, what any) NestedError {
return errorf("%s %v - %v", subject, ErrNotExists, what)
func Unexpected(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func Duplicated(subject, what any) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
func UnexpectedError(err error) NestedError {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func AlreadyExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
}

View File

@@ -1,14 +1,13 @@
module github.com/yusing/go-proxy
go 1.22
toolchain go1.23.1
go 1.22.0
require (
github.com/docker/cli v27.2.1+incompatible
github.com/docker/docker v27.2.1+incompatible
github.com/docker/cli v27.3.1+incompatible
github.com/docker/docker v27.3.1+incompatible
github.com/fsnotify/fsnotify v1.7.0
github.com/go-acme/lego/v4 v4.18.0
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.29.0
@@ -36,6 +35,7 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
@@ -45,10 +45,12 @@ require (
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
golang.org/x/tools v0.25.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gotest.tools/v3 v3.5.1 // indirect
)

View File

@@ -13,10 +13,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/cli v27.2.1+incompatible h1:U5BPtiD0viUzjGAjV1p0MGB8eVA3L3cbIrnyWmSJI70=
github.com/docker/cli v27.2.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker v27.2.1+incompatible h1:fQdiLfW7VLscyoeYEBz7/J8soYFDZV1u6VW6gJEjNMI=
github.com/docker/docker v27.2.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/cli v27.3.1+incompatible h1:qEGdFBF3Xu6SCvCYhc7CzaQTlBmqDuzxPDpigSyeKQQ=
github.com/docker/cli v27.3.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI=
github.com/docker/docker v27.3.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@@ -45,12 +45,16 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 h1:Wqo399gCIufwto+VfwCSvsnfGpF/w5E9CNxSwbpD6No=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0/go.mod h1:qmOFXW2epJhM0qSnUUYpldc7gVz2KMQwJ/QYCDIa7XU=
github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc=
github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
@@ -63,10 +67,14 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/ovh/go-ovh v1.6.0 h1:ixLOwxQdzYDx296sXcgS35TOPEahJkpjMGtzPadCjQI=
github.com/ovh/go-ovh v1.6.0/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
@@ -110,6 +118,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -149,6 +159,8 @@ google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHh
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -2,6 +2,8 @@ package main
import (
"context"
"encoding/json"
"log"
"net/http"
"os"
"os/signal"
@@ -16,6 +18,7 @@ import (
"github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config"
"github.com/yusing/go-proxy/docker"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error"
R "github.com/yusing/go-proxy/route"
"github.com/yusing/go-proxy/server"
@@ -33,14 +36,15 @@ func main() {
}
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
DisableSorting: true,
DisableLevelTruncation: true,
FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
})
if args.Command == common.CommandReload {
if err := apiUtils.ReloadServer(); err.IsNotNil() {
if err := apiUtils.ReloadServer(); err.HasError() {
l.Fatal(err)
}
return
@@ -50,27 +54,47 @@ func main() {
// exit if only validate config
if args.Command == common.CommandValidate {
var err E.NestedError
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.IsNotNil() {
l.WithError(err).Fatalf("config error")
data, err := os.ReadFile(common.ConfigPath)
if err == nil {
err = config.Validate(data).Error()
}
if err = config.Validate(data); err.IsNotNil() {
l.WithError(err).Fatalf("config error")
if err != nil {
l.Fatal("config error: ", err)
}
l.Printf("config OK")
return
}
cfg, err := config.New()
if err.IsNotNil() {
l.Fatalf("config error: %s", err)
cfg, err := config.Load()
if err.IsFatal() {
l.Fatal(err)
}
onShutdown.Add(func() {
docker.CloseAllClients()
cfg.Dispose()
})
if args.Command == common.CommandListConfigs {
printJSON(cfg.Value())
return
}
cfg.StartProxyProviders()
if args.Command == common.CommandListRoutes {
printJSON(cfg.RoutesByAlias())
return
}
if args.Command == common.CommandDebugListEntries {
printJSON(cfg.DumpEntries())
return
}
if err.HasError() {
l.Warn(err)
}
cfg.WatchChanges()
onShutdown.Add(docker.CloseAllClients)
onShutdown.Add(cfg.Dispose)
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT)
@@ -80,23 +104,28 @@ func main() {
autocert := cfg.GetAutoCertProvider()
if autocert != nil {
err = autocert.LoadCert()
if err.IsNotNil() {
l.Error(err)
l.Info("Now attempting to obtain a new certificate...")
if err = autocert.ObtainCert(); err.IsNotNil() {
ctx, certRenewalCancel := context.WithCancel(context.Background())
go autocert.ScheduleRenewal(ctx)
onShutdown.Add(certRenewalCancel)
} else {
if err = autocert.LoadCert(); err.HasError() {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
l.Error(err)
}
l.Debug("obtaining cert due to error loading cert")
if err = autocert.ObtainCert(); err.HasError() {
l.Warn(err)
}
} else {
for name, expiry := range autocert.GetExpiries() {
l.Infof("certificate %q: expire on %s", name, expiry)
}
}
if err.NoError() {
ctx, certRenewalCancel := context.WithCancel(context.Background())
go autocert.ScheduleRenewal(ctx)
onShutdown.Add(certRenewalCancel)
}
for _, expiry := range autocert.GetExpiries() {
l.Infof("certificate expire on %s", expiry)
break
}
} else {
l.Info("autocert not configured")
}
proxyServer := server.InitProxyServer(server.Options{
@@ -120,6 +149,9 @@ func main() {
onShutdown.Add(proxyServer.Stop)
onShutdown.Add(apiServer.Stop)
go idlewatcher.Start()
onShutdown.Add(idlewatcher.Stop)
// wait for signal
<-sig
@@ -147,3 +179,12 @@ func main() {
logrus.Info("timeout waiting for shutdown")
}
}
func printJSON(obj any) {
j, err := E.Check(json.Marshal(obj))
if err.HasError() {
logrus.Fatal(err)
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq"
}

View File

@@ -9,5 +9,5 @@ type (
Provider string `json:"provider"`
Options AutocertProviderOpt `yaml:",flow" json:"options"`
}
AutocertProviderOpt map[string]string
AutocertProviderOpt map[string]any
)

View File

@@ -1,13 +1,16 @@
package model
import (
"strconv"
"strings"
. "github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
F "github.com/yusing/go-proxy/utils/functional"
)
type (
ProxyEntry struct {
ProxyEntry struct { // raw entry object before validation
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
@@ -16,35 +19,70 @@ type (
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
}
ProxyEntries = *F.Map[string, *ProxyEntry]
ProxyEntries = F.Map[string, *ProxyEntry]
)
var NewProxyEntries = F.NewMap[string, *ProxyEntry]
var NewProxyEntries = F.NewMapOf[string, *ProxyEntry]
func (e *ProxyEntry) SetDefaults() {
if e.ProxyProperties == nil {
e.ProxyProperties = &D.ProxyProperties{}
}
if e.Scheme == "" {
if strings.ContainsRune(e.Port, ':') {
switch {
case strings.ContainsRune(e.Port, ':'):
e.Scheme = "tcp"
} else {
switch e.Port {
case "443", "8443":
e.Scheme = "https"
default:
e.Scheme = "http"
case e.ProxyProperties != nil:
if _, ok := ServiceNamePortMapTCP[e.ImageName]; ok {
e.Scheme = "tcp"
}
}
}
if e.Scheme == "" {
switch e.Port {
case "443", "8443":
e.Scheme = "https"
default:
e.Scheme = "http"
}
}
if e.Host == "" {
e.Host = "localhost"
}
if e.Port == "" {
switch e.Scheme {
case "http":
e.Port = "80"
case "https":
e.Port = "443"
e.Port = e.FirstPort
}
if e.Port == "" {
if port, ok := ServiceNamePortMapTCP[e.Port]; ok {
e.Port = strconv.Itoa(port)
} else if port, ok := ImageNamePortMapHTTP[e.Port]; ok {
e.Port = strconv.Itoa(port)
} else {
switch e.Scheme {
case "http":
e.Port = "80"
case "https":
e.Port = "443"
}
}
}
if e.IdleTimeout == "" {
e.IdleTimeout = IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = StopMethodDefault
}
}

View File

@@ -1,10 +1,5 @@
package proxy
var (
PathMode_Forward = "forward"
PathMode_RemovedPath = ""
)
const (
StreamType_UDP string = "udp"
StreamType_TCP string = "tcp"
@@ -19,4 +14,3 @@ var (
HTTPSchemes = []string{"http", "https"}
ValidSchemes = append(StreamSchemes, HTTPSchemes...)
)

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/url"
"time"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
@@ -11,16 +12,24 @@ import (
)
type (
Entry struct { // real model after validation
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
Host T.Host
Port T.Port
URL *url.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
SetHeaders http.Header
HideHeaders []string
/* Docker only */
IdleTimeout time.Duration
WakeTimeout time.Duration
StopMethod T.StopMethod
StopTimeout int
StopSignal T.Signal
DockerHost string
ContainerName string
ContainerRunning bool
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
@@ -30,69 +39,106 @@ type (
}
)
func NewEntry(m *M.ProxyEntry) (any, E.NestedError) {
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.DockerHost != ""
}
func ValidateEntry(m *M.ProxyEntry) (any, E.NestedError) {
m.SetDefaults()
scheme, err := T.NewScheme(m.Scheme)
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
var entry any
e := E.NewBuilder("error validating proxy entry")
if scheme.IsStream() {
return validateStreamEntry(m)
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
return validateEntry(m, scheme)
if err := e.Build(); err.HasError() {
return nil, err
}
return entry, nil
}
func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) {
host, err := T.NewHost(m.Host)
if err.IsNotNil() {
return nil, err
}
port, err := T.NewPort(m.Port)
if err.IsNotNil() {
return nil, err
}
pathPatterns, err := T.NewPathPatterns(m.PathPatterns)
if err.IsNotNil() {
return nil, err
}
setHeaders, err := T.NewHTTPHeaders(m.SetHeaders)
if err.IsNotNil() {
return nil, err
}
func validateRPEntry(m *M.ProxyEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
if err.IsNotNil() {
return nil, err
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
return nil
}
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
ContainerRunning: m.Running,
}
return &Entry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
Host: host,
Port: port,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
}, E.Nil()
}
func validateStreamEntry(m *M.ProxyEntry) (*StreamEntry, E.NestedError) {
host, err := T.NewHost(m.Host)
if err.IsNotNil() {
return nil, err
}
port, err := T.NewStreamPort(m.Port)
if err.IsNotNil() {
return nil, err
}
scheme, err := T.NewStreamScheme(m.Scheme)
if err.IsNotNil() {
return nil, err
func validateStreamEntry(m *M.ProxyEntry, b E.Builder) *StreamEntry {
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := T.ValidateStreamScheme(m.Scheme)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
}, E.Nil()
}
}

View File

@@ -1,23 +1,6 @@
package fields
import (
"strings"
F "github.com/yusing/go-proxy/utils/functional"
type (
Alias string
NewAlias = Alias
)
type Alias string
type Aliases struct{ *F.Slice[Alias] }
func NewAlias(s string) Alias {
return Alias(s)
}
func NewAliases(s string) Aliases {
split := strings.Split(s, ",")
a := Aliases{F.NewSliceN[Alias](len(split))}
for i, v := range split {
a.Set(i, NewAlias(v))
}
return a
}

View File

@@ -7,7 +7,7 @@ import (
E "github.com/yusing/go-proxy/error"
)
func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h := make(http.Header)
for k, v := range headers {
vSplit := strings.Split(v, ",")
@@ -15,5 +15,5 @@ func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h.Add(k, strings.TrimSpace(header))
}
}
return h, E.Nil()
return h, nil
}

View File

@@ -7,6 +7,6 @@ import (
type Host string
type Subdomain = Alias
func NewHost(s string) (Host, E.NestedError) {
return Host(s), E.Nil()
func ValidateHost(s string) (Host, E.NestedError) {
return Host(s), nil
}

View File

@@ -9,7 +9,7 @@ type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm {
case "", "forward":
return PathMode(pm), E.Nil()
return PathMode(pm), nil
default:
return "", E.Invalid("path mode", pm)
}

View File

@@ -16,22 +16,22 @@ func NewPathPattern(s string) (PathPattern, E.NestedError) {
if !pathPattern.MatchString(string(s)) {
return "", E.Invalid("path pattern", s)
}
return PathPattern(s), E.Nil()
return PathPattern(s), nil
}
func NewPathPatterns(s []string) (PathPatterns, E.NestedError) {
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
if len(s) == 0 {
return []PathPattern{"/"}, E.Nil()
return []PathPattern{"/"}, nil
}
pp := make(PathPatterns, len(s))
for i, v := range s {
if pattern, err := NewPathPattern(v); err.IsNotNil() {
if pattern, err := NewPathPattern(v); err.HasError() {
return nil, err
} else {
pp[i] = pattern
}
}
return pp, E.Nil()
return pp, nil
}
var pathPattern = regexp.MustCompile("^((GET|POST|DELETE|PUT|PATCH|HEAD|OPTIONS|CONNECT)\\s)?(/\\w*)+/?$")

View File

@@ -8,7 +8,7 @@ import (
type Port int
func NewPort(v string) (Port, E.NestedError) {
func ValidatePort(v string) (Port, E.NestedError) {
p, err := strconv.Atoi(v)
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
@@ -18,17 +18,17 @@ func NewPort(v string) (Port, E.NestedError) {
func NewPortInt[Int int | uint16](v Int) (Port, E.NestedError) {
pp := Port(v)
if err := pp.boundCheck(); err.IsNotNil() {
if err := pp.boundCheck(); err.HasError() {
return ErrPort, err
}
return pp, E.Nil()
return pp, nil
}
func (p Port) boundCheck() E.NestedError {
if p < MinPort || p > MaxPort {
return E.Invalid("port", p)
}
return E.Nil()
return nil
}
const (

View File

@@ -1,8 +1,6 @@
package fields
import (
"strings"
E "github.com/yusing/go-proxy/error"
)
@@ -11,24 +9,11 @@ type Scheme string
func NewScheme(s string) (Scheme, E.NestedError) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), E.Nil()
return Scheme(s), nil
}
return "", E.Invalid("scheme", s)
}
func NewSchemeFromPort(p string) (Scheme, E.NestedError) {
var s string
switch {
case strings.ContainsRune(p, ':'):
s = "tcp"
case strings.HasSuffix(p, "443"):
s = "https"
default:
s = "http"
}
return Scheme(s), E.Nil()
}
func (s Scheme) IsHTTP() bool { return s == "http" }
func (s Scheme) IsHTTPS() bool { return s == "https" }
func (s Scheme) IsTCP() bool { return s == "tcp" }

View File

@@ -0,0 +1,17 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View File

@@ -0,0 +1,23 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View File

@@ -1,6 +1,7 @@
package fields
import (
"fmt"
"strings"
"github.com/yusing/go-proxy/common"
@@ -12,38 +13,38 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"`
}
func NewStreamPort(p string) (StreamPort, E.NestedError) {
func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
split := strings.Split(p, ":")
if len(split) != 2 {
return StreamPort{}, E.Invalid("stream port", p).With("should be in 'x:y' format")
return StreamPort{}, E.Invalid("stream port", fmt.Sprintf("%q", p)).With("should be in 'x:y' format")
}
listeningPort, err := NewPort(split[0])
if err.IsNotNil() {
listeningPort, err := ValidatePort(split[0])
if err.HasError() {
return StreamPort{}, err
}
if err = listeningPort.boundCheck(); err.IsNotNil() {
if err = listeningPort.boundCheck(); err.HasError() {
return StreamPort{}, err
}
proxyPort, err := NewPort(split[1])
if err.IsNotNil() {
proxyPort, err := ValidatePort(split[1])
if err.HasError() {
proxyPort, err = parseNameToPort(split[1])
if err.IsNotNil() {
if err.HasError() {
return StreamPort{}, err
}
}
if err = proxyPort.boundCheck(); err.IsNotNil() {
if err = proxyPort.boundCheck(); err.HasError() {
return StreamPort{}, err
}
return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, E.Nil()
return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.NestedError) {
port, ok := common.NamePortMapTCP[name]
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return -1, E.Unsupported("service", name)
}
return Port(port), E.Nil()
return Port(port), nil
}

View File

@@ -12,7 +12,7 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"`
}
func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
ss = &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {
@@ -21,14 +21,14 @@ func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
return nil, E.Invalid("stream scheme", s)
}
ss.ListeningScheme, err = NewScheme(parts[0])
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
ss.ProxyScheme, err = NewScheme(parts[1])
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
return ss, E.Nil()
return ss, nil
}
func (s StreamScheme) String() string {

View File

@@ -0,0 +1,18 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

View File

@@ -1,3 +0,0 @@
package provider
const wildcardAlias = "*"

View File

@@ -1,159 +1,205 @@
package provider
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
PT "github.com/yusing/go-proxy/proxy/fields"
R "github.com/yusing/go-proxy/route"
W "github.com/yusing/go-proxy/watcher"
)
type DockerProvider struct {
dockerHost string
dockerHost, hostname string
}
func DockerProviderImpl(dockerHost string) ProviderImpl {
return &DockerProvider{dockerHost: dockerHost}
}
var AliasRefRegex = regexp.MustCompile(`\$\d+`)
// GetProxyEntries returns proxy entries from a docker client.
//
// It retrieves the docker client information using the dockerhelper.GetClientInfo method.
// Then, it iterates over the containers in the docker client information and calls
// the getEntriesFromLabels method to get the proxy entries for each container.
// Any errors encountered during the process are added to the ne error object.
// Finally, it returns the collected proxy entries and the ne error object.
//
// Parameters:
// - p: A pointer to the DockerProvider struct.
//
// Returns:
// - P.EntryModelSlice: (non-nil) A slice of EntryModel structs representing the proxy entries.
// - error: An error object if there was an error retrieving the docker client information or parsing the labels.
func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) {
entries := M.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost)
if err.IsNotNil() {
return entries, err
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
hostname, err := D.ParseDockerHostname(dockerHost)
if err.HasError() {
return nil, err
}
errors := E.NewBuilder("errors when parse docker labels")
for _, container := range info.Containers {
en, err := p.getEntriesFromLabels(&container, info.Host)
if err.IsNotNil() {
errors.Add(err)
}
// although err is not nil
// there may be some valid entries in `en`
dups := entries.MergeWith(en)
// add the duplicate proxy entries to the error
dups.EachKV(func(k string, v *M.ProxyEntry) {
errors.Addf("duplicate alias %s", k)
})
}
return entries, errors.Build()
return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil
}
func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
// Returns a list of proxy entries for a container.
// Always non-nil
func (p *DockerProvider) getEntriesFromLabels(container *types.Container, clientHost string) (M.ProxyEntries, E.NestedError) {
var mainAlias string
var aliases PT.Aliases
// set mainAlias to docker compose service name if available
if serviceName, ok := container.Labels["com.docker.compose.service"]; ok {
mainAlias = serviceName
}
// if mainAlias is not set,
// or container name is different from service name
// use container name
if containerName := strings.TrimPrefix(container.Names[0], "/"); containerName != mainAlias {
mainAlias = containerName
}
if l, ok := container.Labels["proxy.aliases"]; ok {
aliases = PT.NewAliases(l)
delete(container.Labels, "proxy.aliases")
} else {
aliases = PT.NewAliases(mainAlias)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
routes = R.NewRoutes()
entries := M.NewProxyEntries()
// find first port, return if no port exposed
defaultPort, err := findFirstPort(container)
if err.IsNotNil() {
logrus.Debug(mainAlias, " ", err.Error())
info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() {
return routes, E.FailWith("connect to docker", err)
}
errors := E.NewBuilder("errors when parse docker labels")
for _, c := range info.Containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() {
errors.Add(err)
}
// although err is not nil
// there may be some valid entries in `en`
dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error
dups.RangeAll(func(k string, v *M.ProxyEntry) {
errors.Addf("duplicate alias %s", k)
})
}
entries.RangeAll(func(_ string, e *M.ProxyEntry) {
e.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
return routes, errors.Build()
}
func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
}
})
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
b.Add(E.FailWith("inspect container", err))
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
if routes.Has(alias) {
b.Add(E.AlreadyExist("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
}
})
return
}
// Returns a list of proxy entries for a container.
// Always non-nil
func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.ProxyEntries, E.NestedError) {
entries := M.NewProxyEntries()
// init entries map for all aliases
aliases.ForEach(func(a PT.Alias) {
entries.Set(string(a), &M.ProxyEntry{
Alias: string(a),
Host: clientHost,
Port: defaultPort,
for _, a := range container.Aliases {
entries.Store(a, &M.ProxyEntry{
Alias: a,
Host: p.hostname,
ProxyProperties: container.ProxyProperties,
})
})
}
errors := E.NewBuilder("failed to apply label for %q", mainAlias)
errors := E.NewBuilder("failed to apply label")
for key, val := range container.Labels {
lbl, err := D.ParseLabel(key, val)
if err.IsNotNil() {
errors.Add(E.From(err).Subject(key))
continue
}
if lbl.Namespace != D.NSProxy {
continue
}
if lbl.Target == wildcardAlias {
// apply label for all aliases
entries.EachKV(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.IsNotNil() {
errors.Add(E.From(err).Subject(lbl.Target))
}
})
} else {
config, ok := entries.UnsafeGet(lbl.Target)
errors.Add(p.applyLabel(container, entries, key, val))
}
// selecting correct host port
if container.HostConfig.NetworkMode != "host" {
for _, a := range container.Aliases {
entry, ok := entries.Load(a)
if !ok {
errors.Add(E.NotExists("alias", lbl.Target))
continue
}
if err = D.ApplyLabel(config, lbl); err.IsNotNil() {
errors.Add(err.Subject(lbl.Target))
for _, p := range container.Ports {
containerPort := strconv.Itoa(int(p.PrivatePort))
publicPort := strconv.Itoa(int(p.PublicPort))
entryPortSplit := strings.Split(entry.Port, ":")
if len(entryPortSplit) == 2 && entryPortSplit[1] == containerPort {
entryPortSplit[1] = publicPort
} else if entryPortSplit[0] == containerPort {
entryPortSplit[0] = publicPort
}
entry.Port = strings.Join(entryPortSplit, ":")
}
}
}
entries.EachKV(func(a string, e *M.ProxyEntry) {
if e.Port == "" {
entries.UnsafeDelete(a)
}
})
return entries, errors.Build()
return entries, errors.Build().Subject(container.ContainerName)
}
func findFirstPort(c *types.Container) (string, E.NestedError) {
if len(c.Ports) == 0 {
return "", E.FailureWhy("findFirstPort", "no port exposed")
func (p *DockerProvider) applyLabel(container D.Container, entries M.ProxyEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort), E.Nil()
if lbl.Namespace != D.NSProxy {
return
}
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
}
})
} else {
refErr := E.NewBuilder("errors parsing alias references")
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
refErr.Add(E.Invalid("integer", ref))
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.Invalid("index", ref).Extraf("index out of range"))
return ref
}
return container.Aliases[index-1]
})
if refErr.HasError() {
b.Add(refErr.Build())
return
}
config, ok := entries.Load(lbl.Target)
if !ok {
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
}
}
return "", E.Failure("findFirstPort")
return
}

View File

@@ -0,0 +1,167 @@
package provider
import (
"strings"
"testing"
"github.com/docker/docker/api/types"
"github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
F "github.com/yusing/go-proxy/utils/functional"
. "github.com/yusing/go-proxy/utils/testing"
)
func get[KT comparable, VT any](m F.Map[KT, VT], key KT) VT {
v, _ := m.Load(key)
return v
}
var dummyNames = []string{"/a"}
func TestApplyLabelFieldValidity(t *testing.T) {
pathPatterns := `
- /
- POST /upload/{$}
- GET /static
`[1:]
pathPatternsExpect := []string{
"/",
"POST /upload/{$}",
"GET /static",
}
setHeaders := `
X_Custom_Header1: value1
X_Custom_Header1: value2
X_Custom_Header2: value3
`[1:]
setHeadersExpect := map[string]string{
"X_Custom_Header1": "value1, value2",
"X_Custom_Header2": "value3",
}
hideHeaders := `
- X-Custom-Header1
- X-Custom-Header2
`[1:]
hideHeadersExpect := []string{
"X-Custom-Header1",
"X-Custom-Header2",
}
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault,
D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM",
D.LabelStopTimeout: common.StopTimeoutDefault,
D.LabelWakeTimeout: common.WakeTimeoutDefault,
"proxy.*.no_tls_verify": "true",
"proxy.*.scheme": "https",
"proxy.*.host": "app",
"proxy.*.port": "4567",
"proxy.a.no_tls_verify": "true",
"proxy.a.path_patterns": pathPatterns,
"proxy.a.set_headers": setHeaders,
"proxy.a.hide_headers": hideHeaders,
}}, "")
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
a := get(entries, "a")
b := get(entries, "b")
ExpectEqual(t, a.Scheme, "https")
ExpectEqual(t, b.Scheme, "https")
ExpectEqual(t, a.Host, "app")
ExpectEqual(t, b.Host, "app")
ExpectEqual(t, a.Port, "4567")
ExpectEqual(t, b.Port, "4567")
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
ExpectDeepEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, len(b.PathPatterns), 0)
ExpectDeepEqual(t, a.SetHeaders, setHeadersExpect)
ExpectEqual(t, len(b.SetHeaders), 0)
ExpectDeepEqual(t, a.HideHeaders, hideHeadersExpect)
ExpectEqual(t, len(b.HideHeaders), 0)
ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.StopSignal, "SIGTERM")
ExpectEqual(t, b.StopSignal, "SIGTERM")
}
func TestApplyLabel(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b,c",
"proxy.a.no_tls_verify": "true",
"proxy.b.port": "1234",
"proxy.c.scheme": "https",
}}, "")
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
ExpectEqual(t, get(entries, "a").NoTLSVerify, true)
ExpectEqual(t, get(entries, "b").Port, "1234")
ExpectEqual(t, get(entries, "c").Scheme, "https")
}
func TestApplyLabelWithRef(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b,c",
"proxy.$1.host": "localhost",
"proxy.$2.port": "1234",
"proxy.$3.scheme": "https",
}}, "")
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
ExpectEqual(t, get(entries, "a").Host, "localhost")
ExpectEqual(t, get(entries, "b").Port, "1234")
ExpectEqual(t, get(entries, "c").Scheme, "https")
}
func TestApplyLabelWithRefIndexError(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
"proxy.$1.host": "localhost",
"proxy.$4.scheme": "https",
}}, "")
_, err := p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrInvalid, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
"proxy.$0.host": "localhost",
}}, "")
_, err = p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrInvalid, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
}

View File

@@ -1,12 +1,14 @@
package provider
import (
"errors"
"os"
"path"
"github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route"
U "github.com/yusing/go-proxy/utils"
W "github.com/yusing/go-proxy/watcher"
)
@@ -16,37 +18,75 @@ type FileProvider struct {
path string
}
func FileProviderImpl(filename string) ProviderImpl {
return &FileProvider{
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
}
_, err := os.Stat(impl.path)
switch {
case err == nil:
return impl, nil
case errors.Is(err, os.ErrNotExist):
return nil, E.NotExist("file", impl.path)
default:
return nil, E.UnexpectedError(err)
}
}
func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data)
}
func (p *FileProvider) String() string {
return p.fileName
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
b.Add(err)
return
}
routes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Start())
})
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) {
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
routes = R.NewRoutes()
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := M.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err.IsNotNil() {
return entries, E.Failure("read file").Subject(p).With(err)
if err.HasError() {
b.Add(E.FailWith("read file", err))
return
}
ne := E.Failure("validation").Subject(p)
if !common.NoSchemaValidation {
if err = Validate(data); err.IsNotNil() {
return entries, ne.With(err)
if err = Validate(data); err.HasError() {
b.Add(err)
return
}
}
if err = entries.UnmarshalFromYAML(data); err.IsNotNil() {
return entries, ne.With(err)
if err = entries.UnmarshalFromYAML(data); err.HasError() {
b.Add(err)
return
}
return entries, E.Nil()
return R.FromEntries(entries)
}
func (p *FileProvider) NewWatcher() W.Watcher {

View File

@@ -4,38 +4,40 @@ import (
"context"
"fmt"
"path"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route"
W "github.com/yusing/go-proxy/watcher"
)
type ProviderImpl interface {
GetProxyEntries() (M.ProxyEntries, E.NestedError)
NewWatcher() W.Watcher
}
type (
Provider struct {
ProviderImpl
type Provider struct {
ProviderImpl
name string
t ProviderType
routes R.Routes
name string
t ProviderType
routes *R.Routes
reloadReqCh chan struct{}
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
l *logrus.Entry
cooldownCh chan struct{}
}
type ProviderType string
l *logrus.Entry
}
ProviderImpl interface {
NewWatcher() W.Watcher
// even returns error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult
}
ProviderType string
EventResult struct {
nRemoved int
nAdded int
err E.NestedError
}
)
const (
ProviderTypeDocker ProviderType = "docker"
@@ -44,29 +46,33 @@ const (
func newProvider(name string, t ProviderType) *Provider {
p := &Provider{
name: name,
t: t,
routes: R.NewRoutes(),
reloadReqCh: make(chan struct{}, 1),
cooldownCh: make(chan struct{}, 1),
name: name,
t: t,
routes: R.NewRoutes(),
}
p.l = logrus.WithField("provider", p)
go p.processReloadRequests()
return p
}
func NewFileProvider(filename string) *Provider {
name := path.Base(filename)
p := newProvider(name, ProviderTypeFile)
p.ProviderImpl = FileProviderImpl(filename)
p.watcher = p.NewWatcher()
return p
}
func NewDockerProvider(name string, dockerHost string) *Provider {
p := newProvider(name, ProviderTypeDocker)
p.ProviderImpl = DockerProviderImpl(dockerHost)
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
name := path.Base(filename)
p = newProvider(name, ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return p
return
}
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
p = newProvider(name, ProviderTypeDocker)
p.ProviderImpl, err = DockerProviderImpl(dockerHost)
if err != nil {
return nil, err
}
p.watcher = p.NewWatcher()
return
}
func (p *Provider) GetName() string {
@@ -78,26 +84,22 @@ func (p *Provider) GetType() ProviderType {
}
func (p *Provider) String() string {
return fmt.Sprintf("%s: %s", p.t, p.name)
return fmt.Sprintf("%s-%s", p.t, p.name)
}
func (p *Provider) StartAllRoutes() E.NestedError {
err := p.loadRoutes()
func (p *Provider) StartAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors in routes")
defer errors.To(&res)
// start watcher no matter load success or not
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
go p.watchEvents()
errors := E.NewBuilder("errors in routes")
nStarted := 0
nFailed := 0
if err.IsNotNil() {
errors.Add(err)
}
p.routes.EachKVParallel(func(alias string, r R.Route) {
if err := r.Start(); err.IsNotNil() {
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
@@ -106,19 +108,22 @@ func (p *Provider) StartAllRoutes() E.NestedError {
})
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return errors.Build()
return
}
func (p *Provider) StopAllRoutes() E.NestedError {
func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil {
p.watcherCancel()
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
defer errors.To(&res)
nStopped := 0
nFailed := 0
p.routes.EachKVParallel(func(alias string, r R.Route) {
if err := r.Stop(); err.IsNotNil() {
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
} else {
@@ -126,20 +131,22 @@ func (p *Provider) StopAllRoutes() E.NestedError {
}
})
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return errors.Build()
return
}
func (p *Provider) ReloadRoutes() {
select {
case p.reloadReqCh <- struct{}{}:
// Successfully sent reload request
default:
// Reload request already in progress, ignore this request
}
func (p *Provider) RangeRoutes(do func(string, R.Route)) {
p.routes.RangeAll(do)
}
func (p *Provider) GetCurrentRoutes() *R.Routes {
return p.routes
func (p *Provider) GetRoute(alias string) (R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
routes, err := p.LoadRoutesImpl()
p.routes = routes
p.l.Infof("loaded %d routes", routes.Size())
return err
}
func (p *Provider) watchEvents() {
@@ -151,11 +158,15 @@ func (p *Provider) watchEvents() {
case <-p.watcherCtx.Done():
return
case event, ok := <-events:
if !ok {
if !ok { // channel closed
return
}
l.Info(event)
p.ReloadRoutes()
res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() {
l.Error(res.err)
}
case err, ok := <-errs:
if !ok {
return
@@ -167,50 +178,3 @@ func (p *Provider) watchEvents() {
}
}
}
func (p *Provider) processReloadRequests() {
for range p.reloadReqCh {
// prevent busy loop caused by a container
// repeating crashing and restarting
select {
case p.cooldownCh <- struct{}{}:
p.l.Info("Starting to reload routes")
nRoutes := p.routes.Size()
p.StopAllRoutes()
p.loadRoutes()
p.StartAllRoutes()
p.l.Infof("Routes reloaded (%d -> %d)", nRoutes, p.routes.Size())
go func() {
time.Sleep(reloadCooldown)
<-p.cooldownCh
}()
default:
}
}
}
func (p *Provider) loadRoutes() E.NestedError {
entries, err := p.GetProxyEntries()
if err.IsNotNil() {
p.l.Warn(err.Subject(p))
}
p.routes = R.NewRoutes()
errors := E.NewBuilder("errors loading routes from %s", p)
entries.EachKV(func(a string, e *M.ProxyEntry) {
e.Alias = a
r, err := R.NewRoute(e)
if err.IsNotNil() {
errors.Add(err.Subject(a))
} else {
p.routes.Set(a, r)
}
})
return errors.Build()
}
const reloadCooldown = 50 * time.Millisecond

View File

@@ -207,7 +207,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
// TODO: headers in ModifyResponse
func NewReverseProxy(target *url.URL, transport *http.Transport, entry *Entry) *ReverseProxy {
func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
// check on init rather than on request
var setHeaders = func(r *http.Request) {}
var hideHeaders = func(r *http.Request) {}

View File

@@ -2,8 +2,8 @@ package route
import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"net/http"
@@ -11,6 +11,7 @@ import (
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
@@ -23,57 +24,65 @@ type (
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
entry *P.ReverseProxyEntry
mux *http.ServeMux
handler *P.ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
}
URL url.URL
PathKey = PT.PathPattern
SubdomainKey = PT.Alias
)
var httpRoutes = F.NewMap[SubdomainKey, *HTTPRoute]()
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans http.RoundTripper
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
func NewHTTPRoute(entry *P.Entry) (*HTTPRoute, E.NestedError) {
var tr *http.Transport
if entry.NoTLSVerify {
tr = transportNoTLS
trans = transportNoTLS
} else {
tr = transport
trans = transport
}
rp := P.NewReverseProxy(entry.URL, tr, entry)
rp := P.NewReverseProxy(entry.URL, trans, entry)
httpRoutes.Lock()
defer httpRoutes.Unlock()
var r *HTTPRoute
r, ok := httpRoutes.UnsafeGet(entry.Alias)
if !ok {
r = &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
handler: rp,
if entry.UseIdleWatcher() {
regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry)
if err.HasError() {
return err
}
// patch round-tripper
rp.Transport = watcher.PatchRoundTripper(trans)
return nil
}
httpRoutes.UnsafeSet(entry.Alias, r)
}
rewrite := rp.Rewrite
if logrus.GetLevel() == logrus.DebugLevel {
l := logrus.WithField("alias", entry.Alias)
rp.Rewrite = func(pr *P.ProxyRequest) {
l.Debug("request URL: ", pr.In.Host, pr.In.URL.Path)
l.Debug("request headers: ", pr.In.Header)
rewrite(pr)
unregIdleWatcher = func() {
idlewatcher.Unregister(entry.ContainerName)
rp.Transport = trans
}
} else {
rp.Rewrite = rewrite
}
return r, E.Nil()
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
handler: rp,
regIdleWatcher: regIdleWatcher,
unregIdleWatcher: unregIdleWatcher,
}
return r, nil
}
func (r *HTTPRoute) String() string {
@@ -81,18 +90,35 @@ func (r *HTTPRoute) String() string {
}
func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
return err
}
}
r.mux = http.NewServeMux()
for _, p := range r.PathPatterns {
r.mux.HandleFunc(string(p), r.handler.ServeHTTP)
}
httpRoutes.Set(r.Alias, r)
return E.Nil()
httpRoutes.Store(r.Alias, r)
return nil
}
func (r *HTTPRoute) Stop() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.unregIdleWatcher != nil {
r.unregIdleWatcher()
}
r.mux = nil
httpRoutes.Delete(r.Alias)
return E.Nil()
return nil
}
func (u *URL) String() string {
@@ -104,27 +130,26 @@ func (u *URL) MarshalText() (text []byte, err error) {
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMux(r.Host, PathKey(r.URL.Path))
mux, err := findMux(r.Host)
if err != nil {
err = E.Failure("request").
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
With(err)
http.Error(w, err.Error(), http.StatusNotFound)
http.Error(w, err.String(), http.StatusNotFound)
logrus.Error(err)
return
}
mux.ServeHTTP(w, r)
}
func findMux(host string, path PathKey) (*http.ServeMux, error) {
func findMux(host string) (*http.ServeMux, E.NestedError) {
sd := strings.Split(host, ".")[0]
if r, ok := httpRoutes.UnsafeGet(PT.Alias(sd)); ok {
if r, ok := httpRoutes.Load(PT.Alias(sd)); ok {
return r.mux, nil
}
return nil, E.NotExists("route", fmt.Sprintf("subdomain: %s, path: %s", sd, path))
return nil, E.NotExist("route", sd)
}
// TODO: default + per proxy
var (
transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -135,10 +160,13 @@ var (
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
}
transportNoTLS = func() *http.Transport {
var clone = transport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux()
)

View File

@@ -1,6 +1,9 @@
package route
import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
P "github.com/yusing/go-proxy/proxy"
@@ -9,27 +12,81 @@ import (
type (
Route interface {
RouteImpl
Entry() *M.ProxyEntry
Type() RouteType
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteType string
RouteImpl interface {
Start() E.NestedError
Stop() E.NestedError
String() string
}
Routes = F.Map[string, Route]
route struct {
RouteImpl
type_ RouteType
entry *M.ProxyEntry
}
)
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias
var NewRoutes = F.NewMap[string, Route]
var NewRoutes = F.NewMapOf[string, Route]
func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) {
entry, err := P.NewEntry(en)
if err.IsNotNil() {
rt, err := P.ValidateEntry(en)
if err.HasError() {
return nil, err
}
switch e := entry.(type) {
var t RouteType
switch e := rt.(type) {
case *P.StreamEntry:
return NewStreamRoute(e)
case *P.Entry:
return NewHTTPRoute(e)
rt, err = NewStreamRoute(e)
t = RouteTypeStream
case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
default:
panic("bug: should not reach here")
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err
}
func (rt *route) Entry() *M.ProxyEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host))
return url
}
func FromEntries(entries M.ProxyEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)
}
})
return routes, b.Build()
}

View File

@@ -1,6 +1,7 @@
package route
import (
"context"
"fmt"
"sync"
"sync/atomic"
@@ -12,11 +13,13 @@ import (
)
type StreamRoute struct {
*P.StreamEntry
P.StreamEntry
StreamImpl `json:"-"`
wg sync.WaitGroup
stopCh chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
connCh chan any
started atomic.Bool
l logrus.FieldLogger
@@ -35,10 +38,8 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
base := &StreamRoute{
StreamEntry: entry,
wg: sync.WaitGroup{},
stopCh: make(chan struct{}, 1),
connCh: make(chan any),
StreamEntry: *entry,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base)
@@ -46,34 +47,35 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
base.StreamImpl = NewUDPRoute(base)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, E.Nil()
return base, nil
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("%s-stream: %s", r.Scheme, r.Alias)
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
}
func (r *StreamRoute) Start() E.NestedError {
if r.started.Load() {
return E.Invalid("state", "already started")
return nil
}
r.ctx, r.cancel = context.WithCancel(context.Background())
r.wg.Wait()
if err := r.Setup(); err != nil {
return E.Failure("setup").With(err)
return E.FailWith("setup", err)
}
r.started.Store(true)
r.wg.Add(2)
go r.grAcceptConnections()
go r.grHandleConnections()
return E.Nil()
return nil
}
func (r *StreamRoute) Stop() E.NestedError {
if !r.started.Load() {
return E.Invalid("state", "not started")
return nil
}
l := r.l
close(r.stopCh)
r.cancel()
r.CloseListeners()
done := make(chan struct{}, 1)
@@ -82,13 +84,16 @@ func (r *StreamRoute) Stop() E.NestedError {
close(done)
}()
select {
case <-done:
l.Info("stopped listening")
case <-time.After(streamStopListenTimeout):
l.Error("timed out waiting for connections")
timeout := time.After(streamStopListenTimeout)
for {
select {
case <-done:
l.Debug("stopped listening")
return nil
case <-timeout:
return E.FailedWhy("stop", "timed out")
}
}
return E.Nil()
}
func (r *StreamRoute) grAcceptConnections() {
@@ -96,13 +101,13 @@ func (r *StreamRoute) grAcceptConnections() {
for {
select {
case <-r.stopCh:
case <-r.ctx.Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-r.stopCh:
case <-r.ctx.Done():
return
default:
r.l.Error(err)
@@ -119,7 +124,7 @@ func (r *StreamRoute) grHandleConnections() {
for {
select {
case <-r.stopCh:
case <-r.ctx.Done():
return
case conn := <-r.connCh:
go func() {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"sync"
"syscall"
"time"
U "github.com/yusing/go-proxy/utils"
@@ -24,7 +25,6 @@ type TCPRoute struct {
func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{
StreamRoute: base,
listener: nil,
pipe: make(Pipes, 0),
}
}
@@ -38,16 +38,16 @@ func (route *TCPRoute) Setup() error {
return nil
}
func (route *TCPRoute) Accept() (interface{}, error) {
func (route *TCPRoute) Accept() (any, error) {
return route.listener.Accept()
}
func (route *TCPRoute) Handle(c interface{}) error {
func (route *TCPRoute) Handle(c any) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout)
ctx, cancel := context.WithTimeout(route.ctx, tcpDialTimeout)
defer cancel()
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
@@ -58,16 +58,11 @@ func (route *TCPRoute) Handle(c interface{}) error {
return err
}
pipeCtx, pipeCancel := context.WithCancel(context.Background())
go func() {
<-route.stopCh
pipeCancel()
}()
route.mu.Lock()
pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn)
defer route.mu.Unlock()
pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
@@ -78,8 +73,15 @@ func (route *TCPRoute) CloseListeners() {
route.listener.Close()
route.listener = nil
for _, pipe := range route.pipe {
if err := pipe.Stop(); err.IsNotNil() {
route.l.Error(err)
if err := pipe.Stop(); err != nil {
switch err {
// target closing connection
// TODO: handle this by fixing utils/io.go
case net.ErrClosed, syscall.EPIPE:
return
default:
route.l.Error(err)
}
}
}
}

View File

@@ -1,7 +1,6 @@
package route
import (
"context"
"fmt"
"io"
"net"
@@ -36,7 +35,7 @@ func NewUDPRoute(base *StreamRoute) StreamImpl {
}
func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ProxyPort))
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
@@ -55,7 +54,7 @@ func (route *UDPRoute) Setup() error {
return nil
}
func (route *UDPRoute) Accept() (interface{}, error) {
func (route *UDPRoute) Accept() (any, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
@@ -84,15 +83,10 @@ func (route *UDPRoute) Accept() (interface{}, error) {
srcConn.Close()
return nil, err
}
pipeCtx, pipeCancel := context.WithCancel(context.Background())
go func() {
<-route.stopCh
pipeCancel()
}()
conn = &UDPConn{
srcConn,
dstConn,
utils.NewBidirectionalPipe(pipeCtx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
utils.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap[key] = conn
}
@@ -103,7 +97,7 @@ func (route *UDPRoute) Accept() (interface{}, error) {
return conn, err
}
func (route *UDPRoute) Handle(c interface{}) error {
func (route *UDPRoute) Handle(c any) error {
return c.(*UDPConn).Start()
}

View File

@@ -20,16 +20,16 @@ func FormatDuration(d time.Duration) string {
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%d Day%s", days, pluralize(days)))
parts = append(parts, fmt.Sprintf("%d day%s", days, pluralize(days)))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%d Hour%s", hours, pluralize(hours)))
parts = append(parts, fmt.Sprintf("%d hour%s", hours, pluralize(hours)))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%d Minute%s", minutes, pluralize(minutes)))
parts = append(parts, fmt.Sprintf("%d minute%s", minutes, pluralize(minutes)))
}
if seconds > 0 {
parts = append(parts, fmt.Sprintf("%d Second%s", seconds, pluralize(seconds)))
parts = append(parts, fmt.Sprintf("%d second%s", seconds, pluralize(seconds)))
}
// Join the parts with appropriate connectors
@@ -42,6 +42,15 @@ func FormatDuration(d time.Duration) string {
return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1]
}
func ParseBool(s string) bool {
switch strings.ToLower(s) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
func pluralize(n int64) string {
if n > 1 {
return "s"

View File

@@ -1,15 +0,0 @@
package utils
import (
"os"
"path"
)
func FileOK(p string) bool {
_, err := os.Stat(p)
return err == nil
}
func FileName(p string) string {
return path.Base(p)
}

View File

@@ -2,25 +2,25 @@ package functional
import "sync"
func ForEachKey[K comparable, V interface{}](obj map[K]V, do func(K)) {
func ForEachKey[K comparable, V any](obj map[K]V, do func(K)) {
for k := range obj {
do(k)
}
}
func ForEachValue[K comparable, V interface{}](obj map[K]V, do func(V)) {
func ForEachValue[K comparable, V any](obj map[K]V, do func(V)) {
for _, v := range obj {
do(v)
}
}
func ForEachKV[K comparable, V interface{}](obj map[K]V, do func(K, V)) {
func ForEachKV[K comparable, V any](obj map[K]V, do func(K, V)) {
for k, v := range obj {
do(k, v)
}
}
func ParallelForEach[T interface{}](obj []T, do func(T)) {
func ParallelForEach[T any](obj []T, do func(T)) {
var wg sync.WaitGroup
wg.Add(len(obj))
for _, v := range obj {
@@ -32,7 +32,7 @@ func ParallelForEach[T interface{}](obj []T, do func(T)) {
wg.Wait()
}
func ParallelForEachKey[K comparable, V interface{}](obj map[K]V, do func(K)) {
func ParallelForEachKey[K comparable, V any](obj map[K]V, do func(K)) {
var wg sync.WaitGroup
wg.Add(len(obj))
for k := range obj {
@@ -44,7 +44,7 @@ func ParallelForEachKey[K comparable, V interface{}](obj map[K]V, do func(K)) {
wg.Wait()
}
func ParallelForEachValue[K comparable, V interface{}](obj map[K]V, do func(V)) {
func ParallelForEachValue[K comparable, V any](obj map[K]V, do func(V)) {
var wg sync.WaitGroup
wg.Add(len(obj))
for _, v := range obj {
@@ -56,7 +56,7 @@ func ParallelForEachValue[K comparable, V interface{}](obj map[K]V, do func(V))
wg.Wait()
}
func ParallelForEachKV[K comparable, V interface{}](obj map[K]V, do func(K, V)) {
func ParallelForEachKV[K comparable, V any](obj map[K]V, do func(K, V)) {
var wg sync.WaitGroup
wg.Add(len(obj))
for k, v := range obj {

View File

@@ -1,229 +1,116 @@
package functional
import (
"context"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/yaml.v3"
E "github.com/yusing/go-proxy/error"
)
type Map[KT comparable, VT interface{}] struct {
m map[KT]VT
defVals map[KT]VT
sync.RWMutex
type Map[KT comparable, VT any] struct {
*xsync.MapOf[KT, VT]
}
// NewMap creates a new Map with the given map as its initial values.
//
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMap[KT comparable, VT interface{}](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] {
return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)}
}
// NewMapOf creates a new Map with the given map as its initial values.
//
// Type parameters:
// - M: type for the new map.
//
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapOf[M Map[KT, VT], KT comparable, VT interface{}](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
}
// NewMapFrom creates a new Map with the given map as its initial values.
//
// Parameters:
// - from: a map of type KT to VT, which will be the initial values of the Map.
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapFrom[KT comparable, VT interface{}](from map[KT]VT, dv ...map[KT]VT) *Map[KT, VT] {
if len(dv) > 0 {
return &Map[KT, VT]{m: from, defVals: dv[0]}
func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) {
res = NewMapOf[KT, VT](xsync.WithPresize(len(m)))
for k, v := range m {
res.Store(k, v)
}
return &Map[KT, VT]{m: from}
return
}
func (m *Map[KT, VT]) Set(key KT, value VT) {
m.Lock()
m.m[key] = value
m.Unlock()
}
func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) {
result := make(chan CT, 1)
func (m *Map[KT, VT]) Get(key KT) VT {
m.RLock()
defer m.RUnlock()
value, ok := m.m[key]
if !ok && m.defVals != nil {
return m.defVals[key]
}
return value
}
// Find searches for the first element in the map that satisfies the given criteria.
//
// Parameters:
// - criteria: a function that takes a value of type VT and returns a tuple of any type and a boolean.
//
// Return:
// - any: the first value that satisfies the criteria, or nil if no match is found.
func (m *Map[KT, VT]) Find(criteria func(VT) (any, bool)) any {
m.RLock()
defer m.RUnlock()
result := make(chan any)
wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, v := range m.m {
wg.Add(1)
go func(val VT) {
defer wg.Done()
if value, ok := criteria(val); ok {
select {
case result <- value:
cancel() // Cancel other goroutines if a result is found
case <-ctx.Done(): // If already cancelled
return
}
m.Range(func(key KT, value VT) bool {
select {
case <-result: // already have a result
return false // stop iteration
default:
if got, ok := criteria(value); ok {
result <- got
return false
}
}(v)
}
go func() {
wg.Wait()
close(result)
}()
// The first valid match, if any
select {
case res, ok := <-result:
if ok {
return res
return true
}
case <-ctx.Done():
})
select {
case v := <-result:
return v
default:
return
}
return nil // Return nil if no matches found
}
func (m *Map[KT, VT]) UnsafeGet(key KT) (VT, bool) {
value, ok := m.m[key]
return value, ok
}
func (m *Map[KT, VT]) UnsafeSet(key KT, value VT) {
m.m[key] = value
}
func (m *Map[KT, VT]) Delete(key KT) {
m.Lock()
delete(m.m, key)
m.Unlock()
}
func (m *Map[KT, VT]) UnsafeDelete(key KT) {
delete(m.m, key)
}
// MergeWith merges the contents of another Map[KT, VT]
// into the current Map[KT, VT] and
// returns a map that were duplicated.
// MergeFrom add contents from another `Map`, ignore duplicated keys
//
// Parameters:
// - other: a pointer to another Map[KT, VT] to be merged into the current Map[KT, VT].
// - other: `Map` of values to add from
//
// Return:
// - Map[KT, VT]: a map of key-value pairs that were duplicated during the merge.
func (m *Map[KT, VT]) MergeWith(other *Map[KT, VT]) Map[KT, VT] {
dups := make(map[KT]VT)
// - Map: a `Map` of duplicated keys-value pairs
func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] {
dups := NewMapOf[KT, VT]()
m.Lock()
for k, v := range other.m {
if _, isDup := m.m[k]; !isDup {
m.m[k] = v
other.Range(func(k KT, v VT) bool {
if _, ok := m.Load(k); ok {
dups.Store(k, v)
} else {
dups[k] = v
m.Store(k, v)
}
}
m.Unlock()
return Map[KT, VT]{m: dups}
return true
})
return dups
}
func (m *Map[KT, VT]) Clear() {
m.Lock()
m.m = make(map[KT]VT)
m.Unlock()
func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
m.Range(func(k KT, v VT) bool {
do(k, v)
return true
})
}
func (m *Map[KT, VT]) Size() int {
m.RLock()
defer m.RUnlock()
return len(m.m)
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool {
if criteria(v) {
m.Delete(k)
}
return true
})
}
func (m *Map[KT, VT]) Contains(key KT) bool {
m.RLock()
_, ok := m.m[key]
m.RUnlock()
func (m Map[KT, VT]) Has(k KT) bool {
_, ok := m.Load(k)
return ok
}
func (m *Map[KT, VT]) Clone() *Map[KT, VT] {
m.RLock()
defer m.RUnlock()
clone := make(map[KT]VT, len(m.m))
for k, v := range m.m {
clone[k] = v
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
if m.Size() != 0 {
return E.FailedWhy("unmarshal from yaml", "map is not empty")
}
return &Map[KT, VT]{m: clone, defVals: m.defVals}
}
func (m *Map[KT, VT]) EachKV(fn func(k KT, v VT)) {
m.Lock()
for k, v := range m.m {
fn(k, v)
tmp := make(map[KT]VT)
if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() {
return err
}
m.Unlock()
}
func (m *Map[KT, VT]) Each(fn func(v VT)) {
m.Lock()
for _, v := range m.m {
fn(v)
for k, v := range tmp {
m.Store(k, v)
}
m.Unlock()
return nil
}
func (m *Map[KT, VT]) EachParallel(fn func(v VT)) {
m.Lock()
ParallelForEachValue(m.m, fn)
m.Unlock()
}
func (m *Map[KT, VT]) EachKVParallel(fn func(k KT, v VT)) {
m.Lock()
ParallelForEachKV(m.m, fn)
m.Unlock()
}
func (m *Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
return E.From(yaml.Unmarshal(data, m.m))
}
func (m *Map[KT, VT]) Iterator() map[KT]VT {
return m.m
func (m Map[KT, VT]) String() string {
tmp := make(map[KT]VT, m.Size())
m.RangeAll(func(k KT, v VT) {
tmp[k] = v
})
data, err := yaml.Marshal(tmp)
if err != nil {
return err.Error()
}
return string(data)
}

View File

@@ -0,0 +1,75 @@
package functional_test
import (
"testing"
. "github.com/yusing/go-proxy/utils/functional"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestNewMapFrom(t *testing.T) {
m := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
})
ExpectEqual(t, m.Size(), 3)
ExpectTrue(t, m.Has("a"))
ExpectTrue(t, m.Has("b"))
ExpectTrue(t, m.Has("c"))
}
func TestMapFind(t *testing.T) {
m := NewMapFrom(map[string]map[string]int{
"a": {
"a": 1,
},
"b": {
"a": 1,
"b": 2,
},
"c": {
"b": 2,
"c": 3,
},
})
res := MapFind(m, func(inner map[string]int) (int, bool) {
if _, ok := inner["c"]; ok && inner["c"] == 3 {
return inner["c"], true
}
return 0, false
})
ExpectEqual(t, res, 3)
}
func TestMergeFrom(t *testing.T) {
m1 := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
"d": 4,
})
m2 := NewMapFrom(map[string]int{
"a": 1,
"c": 123,
"e": 456,
"f": 6,
})
dup := m1.MergeFrom(m2)
ExpectEqual(t, m1.Size(), 6)
ExpectTrue(t, m1.Has("e"))
ExpectTrue(t, m1.Has("f"))
c, _ := m1.Load("c")
d, _ := m1.Load("d")
e, _ := m1.Load("e")
f, _ := m1.Load("f")
ExpectEqual(t, c, 3)
ExpectEqual(t, d, 4)
ExpectEqual(t, e, 456)
ExpectEqual(t, f, 6)
ExpectEqual(t, dup.Size(), 2)
ExpectTrue(t, dup.Has("a"))
ExpectTrue(t, dup.Has("c"))
}

View File

@@ -2,6 +2,7 @@ package utils
import (
"context"
"encoding/json"
"io"
"os"
"sync/atomic"
@@ -9,15 +10,8 @@ import (
E "github.com/yusing/go-proxy/error"
)
// TODO: move to "utils/io"
type (
Reader interface {
Read() ([]byte, E.NestedError)
}
StdReader struct {
r Reader
}
FileReader struct {
Path string
}
@@ -28,13 +22,6 @@ type (
closed atomic.Bool
}
StdReadCloser struct {
r *ReadCloser
}
ByteReader []byte
NewByteReader = ByteReader
Pipe struct {
r ReadCloser
w io.WriteCloser
@@ -43,49 +30,25 @@ type (
}
BidirectionalPipe struct {
pSrcDst Pipe
pDstSrc Pipe
pSrcDst *Pipe
pDstSrc *Pipe
}
)
func NewFileReader(path string) *FileReader {
return &FileReader{Path: path}
}
func (r StdReader) Read() ([]byte, error) {
return r.r.Read()
}
func (r *FileReader) Read() ([]byte, E.NestedError) {
return E.Check(os.ReadFile(r.Path))
}
func (r ByteReader) Read() ([]byte, E.NestedError) {
return r, E.Nil()
}
func (r *ReadCloser) Read(p []byte) (int, E.NestedError) {
func (r *ReadCloser) Read(p []byte) (int, error) {
select {
case <-r.ctx.Done():
return 0, E.From(r.ctx.Err())
return 0, r.ctx.Err()
default:
return E.Check(r.r.Read(p))
return r.r.Read(p)
}
}
func (r *ReadCloser) Close() E.NestedError {
func (r *ReadCloser) Close() error {
if r.closed.Load() {
return E.Nil()
return nil
}
r.closed.Store(true)
return E.From(r.r.Close())
}
func (r StdReadCloser) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r StdReadCloser) Close() error {
return r.r.Close()
}
@@ -99,35 +62,35 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
}
}
func (p *Pipe) Start() E.NestedError {
return Copy(p.ctx, p.w, &StdReadCloser{&p.r})
func (p *Pipe) Start() error {
return Copy(p.ctx, p.w, &p.r)
}
func (p *Pipe) Stop() E.NestedError {
func (p *Pipe) Stop() error {
p.cancel()
return E.Join("error stopping pipe", p.r.Close(), p.w.Close())
return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error()
}
func (p *Pipe) Write(b []byte) (int, E.NestedError) {
return E.Check(p.w.Write(b))
func (p *Pipe) Write(b []byte) (int, error) {
return p.w.Write(b)
}
func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, rw1, rw2),
pDstSrc: *NewPipe(ctx, rw2, rw1),
pSrcDst: NewPipe(ctx, rw1, rw2),
pDstSrc: NewPipe(ctx, rw2, rw1),
}
}
func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, listener, client),
pDstSrc: *NewPipe(ctx, client, target),
pSrcDst: NewPipe(ctx, listener, client),
pDstSrc: NewPipe(ctx, client, target),
}
}
func (p *BidirectionalPipe) Start() E.NestedError {
errCh := make(chan E.NestedError, 2)
func (p *BidirectionalPipe) Start() error {
errCh := make(chan error, 2)
go func() {
errCh <- p.pSrcDst.Start()
}()
@@ -135,18 +98,34 @@ func (p *BidirectionalPipe) Start() E.NestedError {
errCh <- p.pDstSrc.Start()
}()
for err := range errCh {
if err.IsNotNil() {
if err != nil {
return err
}
}
return E.Nil()
return nil
}
func (p *BidirectionalPipe) Stop() E.NestedError {
return E.Join("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop())
func (p *BidirectionalPipe) Stop() error {
return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error()
}
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) E.NestedError {
_, err := io.Copy(dst, StdReadCloser{&ReadCloser{ctx: ctx, r: src}})
return E.From(err)
}
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error {
_, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src})
return err
}
func LoadJson[T any](path string, pointer *T) E.NestedError {
data, err := E.Check(os.ReadFile(path))
if err.HasError() {
return err
}
return E.From(json.Unmarshal(data, pointer))
}
func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
data, err := E.Check(json.Marshal(pointer))
if err.HasError() {
return err
}
return E.From(os.WriteFile(path, data, perm))
}

View File

@@ -1,24 +0,0 @@
package utils
import (
"net/http"
"reflect"
"strings"
E "github.com/yusing/go-proxy/error"
)
func snakeToPascal(s string) string {
toHyphenCamel := http.CanonicalHeaderKey(strings.ReplaceAll(s, "_", "-"))
return strings.ReplaceAll(toHyphenCamel, "-", "")
}
func SetFieldFromSnake[T, VT any](obj *T, field string, value VT) E.NestedError {
field = snakeToPascal(field)
prop := reflect.ValueOf(obj).Elem().FieldByName(field)
if prop.Kind() == 0 {
return E.Invalid("field", field)
}
prop.Set(reflect.ValueOf(value))
return E.Nil()
}

View File

@@ -2,7 +2,6 @@ package utils
import (
"github.com/santhosh-tekuri/jsonschema"
"github.com/yusing/go-proxy/common"
)
var schemaCompiler = func() *jsonschema.Compiler {
@@ -11,16 +10,13 @@ var schemaCompiler = func() *jsonschema.Compiler {
return c
}()
var schemaStorage = make(map[string] *jsonschema.Schema)
var schemaStorage = make(map[string]*jsonschema.Schema)
func GetSchema(path string) *jsonschema.Schema {
if common.NoSchemaValidation {
panic("bug: GetSchema called when schema validation disabled")
}
if schema, ok := schemaStorage[path]; ok {
return schema
}
schema := schemaCompiler.MustCompile(path)
schemaStorage[path] = schema
return schema
}
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/santhosh-tekuri/jsonschema"
E "github.com/yusing/go-proxy/error"
@@ -12,50 +13,31 @@ import (
)
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
var i interface{}
var i any
err := yaml.Unmarshal(data, &i)
if err != nil {
return E.Failure("unmarshal yaml").With(err)
return E.FailWith("unmarshal yaml", err)
}
m, err := json.Marshal(i)
if err != nil {
return E.Failure("marshal json").With(err)
return E.FailWith("marshal json", err)
}
err = schema.Validate(bytes.NewReader(m))
if err == nil {
return E.Nil()
return nil
}
errors := E.NewBuilder("yaml validation error")
for _, e := range err.(*jsonschema.ValidationError).Causes {
errors.Add(e)
errors.AddE(e)
}
return errors.Build()
}
// TryJsonStringify converts the given object to a JSON string.
//
// It takes an object of any type and attempts to marshal it into a JSON string.
// If the marshaling is successful, the JSON string is returned.
// If the marshaling fails, the object is converted to a string using fmt.Sprint and returned.
//
// Parameters:
// - o: The object to be converted to a JSON string.
//
// Return type:
// - string: The JSON string representation of the object.
func TryJsonStringify(o any) string {
b, err := json.Marshal(o)
if err != nil {
return fmt.Sprint(o)
}
return string(b)
}
// Serialize converts the given data into a map[string]interface{} representation.
// Serialize converts the given data into a map[string]any representation.
//
// It uses reflection to inspect the data type and handle different kinds of data.
// For a struct, it extracts the fields using the json tag if present, or the field name if not.
@@ -66,9 +48,9 @@ func TryJsonStringify(o any) string {
// - data: The data to be converted into a map.
//
// Returns:
// - result: The resulting map[string]interface{} representation of the data.
// - result: The resulting map[string]any representation of the data.
// - error: An error if the data type is unsupported or if there is an error during conversion.
func Serialize(data interface{}) (SerializedObject, error) {
func Serialize(data any) (SerializedObject, E.NestedError) {
result := make(map[string]any)
// Use reflection to inspect the data type
@@ -76,7 +58,7 @@ func Serialize(data interface{}) (SerializedObject, error) {
// Check if the value is valid
if !value.IsValid() {
return nil, fmt.Errorf("invalid data")
return nil, E.Invalid("data", fmt.Sprintf("type: %T", data))
}
// Dereference pointers if necessary
@@ -107,7 +89,7 @@ func Serialize(data interface{}) (SerializedObject, error) {
} else if field.Anonymous {
// If the field is an embedded struct, add its fields to the result
fieldMap, err := Serialize(value.Field(i).Interface())
if err != nil {
if err.HasError() {
return nil, err
}
for k, v := range fieldMap {
@@ -118,10 +100,80 @@ func Serialize(data interface{}) (SerializedObject, error) {
}
}
default:
return nil, fmt.Errorf("unsupported type: %s", value.Kind())
// return nil, fmt.Errorf("unsupported type: %s", value.Kind())
return nil, E.Unsupported("type", value.Kind())
}
return result, nil
}
type SerializedObject map[string]any
func Deserialize(src map[string]any, target any) E.NestedError {
// convert data fields to lower no-snake
// convert target fields to lower
// then check if the field of data is in the target
mapping := make(map[string]string)
t := reflect.TypeOf(target).Elem()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
snakeCaseField := strings.ToLower(field.Name)
mapping[snakeCaseField] = field.Name
}
for k, v := range src {
kCleaned := toLowerNoSnake(k)
if fieldName, ok := mapping[kCleaned]; ok {
prop := reflect.ValueOf(target).Elem().FieldByName(fieldName)
propType := prop.Type()
isPtr := prop.Kind() == reflect.Ptr
if prop.CanSet() {
val := reflect.ValueOf(v)
vType := val.Type()
switch {
case isPtr && vType.ConvertibleTo(propType.Elem()):
ptr := reflect.New(propType.Elem())
ptr.Elem().Set(val.Convert(propType.Elem()))
prop.Set(ptr)
case vType.ConvertibleTo(propType):
prop.Set(val.Convert(propType))
case isPtr:
var vSerialized SerializedObject
vSerialized, ok = v.(SerializedObject)
if !ok {
if vType.ConvertibleTo(reflect.TypeFor[SerializedObject]()) {
vSerialized = val.Convert(reflect.TypeFor[SerializedObject]()).Interface().(SerializedObject)
} else {
return E.Failure(fmt.Sprintf("convert %s (%T) to %s", k, v, reflect.TypeFor[SerializedObject]()))
}
}
propNew := reflect.New(propType.Elem())
err := Deserialize(vSerialized, propNew.Interface())
if err.HasError() {
return E.Failure("set field").With(err).Subject(k)
}
prop.Set(propNew)
default:
return E.Unsupported("field", k).Extraf("type=%s", propType)
}
} else {
return E.Unsupported("field", k).Extraf("type=%s", propType)
}
} else {
return E.Failure("unknown field").With(k)
}
}
return nil
}
func DeserializeJson(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j))
if err.HasError() {
return err
}
return E.From(json.Unmarshal(data, target))
}
func toLowerNoSnake(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
}
type SerializedObject = map[string]any

11
src/utils/string.go Normal file
View File

@@ -0,0 +1,11 @@
package utils
import "strings"
func CommaSeperatedList(s string) []string {
res := strings.Split(s, ",")
for i, part := range res {
res[i] = strings.TrimSpace(part)
}
return res
}

View File

@@ -0,0 +1,59 @@
package utils
import (
"errors"
"reflect"
"testing"
)
func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
}
func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("expected err %s, got nil", expected.Error())
}
}
func ExpectEqual[T comparable](t *testing.T, got T, want T) {
t.Helper()
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
}
}
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got)
}
}
func ExpectTrue(t *testing.T, got bool) {
t.Helper()
if !got {
t.Errorf("expected true, got false")
}
}
func ExpectFalse(t *testing.T, got bool) {
t.Helper()
if got {
t.Errorf("expected false, got true")
}
}
func ExpectType[T any](t *testing.T, got any) T {
t.Helper()
tExpect := reflect.TypeFor[T]()
_, ok := got.(T)
if !ok {
t.Errorf("expected type %s, got %T", tExpect, got)
}
return got.(T)
}

View File

@@ -2,70 +2,103 @@ package watcher
import (
"context"
"fmt"
"time"
"github.com/docker/docker/api/types/events"
docker_events "github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/watcher/events"
)
type DockerWatcher struct {
host string
type (
DockerWatcher struct {
host string
client D.Client
logrus.FieldLogger
}
DockerListOptions = docker_events.ListOptions
)
// https://docs.docker.com/reference/api/engine/version/v1.47/#tag/System/operation/SystemPingHead
var (
DockerFilterContainer = filters.Arg("type", string(docker_events.ContainerEventType))
DockerFilterStart = filters.Arg("event", string(docker_events.ActionStart))
DockerFilterStop = filters.Arg("event", string(docker_events.ActionStop))
DockerFilterDie = filters.Arg("event", string(docker_events.ActionDie))
DockerFilterKill = filters.Arg("event", string(docker_events.ActionKill))
DockerFilterPause = filters.Arg("event", string(docker_events.ActionPause))
DockerFilterUnpause = filters.Arg("event", string(docker_events.ActionUnPause))
NewDockerFilter = filters.NewArgs
)
func DockerrFilterContainerName(name string) filters.KeyValuePair {
return filters.Arg("container", name)
}
func NewDockerWatcher(host string) *DockerWatcher {
return &DockerWatcher{host: host}
func NewDockerWatcher(host string) DockerWatcher {
return DockerWatcher{host: host, FieldLogger: logrus.WithField("module", "docker_watcher")}
}
func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
func NewDockerWatcherWithClient(client D.Client) DockerWatcher {
return DockerWatcher{client: client, FieldLogger: logrus.WithField("module", "docker_watcher")}
}
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
return w.EventsWithOptions(ctx, optionsWatchAll)
}
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) {
eventCh := make(chan Event)
errCh := make(chan E.NestedError)
started := make(chan struct{})
go func() {
defer close(eventCh)
defer close(errCh)
var cl D.Client
var err E.NestedError
for range 3 {
cl, err = D.ConnectClient(w.host)
if err.IsNil() {
break
if !w.client.Connected() {
var err E.NestedError
for range 3 {
w.client, err = D.ConnectClient(w.host)
if err != nil {
defer w.client.Close()
break
}
time.Sleep(1 * time.Second)
}
if err.HasError() {
errCh <- E.FailWith("docker connection", err)
return
}
errCh <- E.From(err)
time.Sleep(1 * time.Second)
}
if err.IsNotNil() {
errCh <- E.Failure("connecting to docker")
return
}
cEventCh, cErrCh := cl.Events(ctx, dwOptions)
cEventCh, cErrCh := w.client.Events(ctx, options)
started <- struct{}{}
for {
select {
case <-ctx.Done():
if err := <-cErrCh; err != nil {
errCh <- E.From(err)
if err := E.From(ctx.Err()); err != nil && err.IsNot(context.Canceled) {
errCh <- err
}
return
case msg := <-cEventCh:
var Action Action
switch msg.Action {
case events.ActionStart:
Action = ActionCreated
case events.ActionDie:
Action = ActionDeleted
default: // NOTE: should not happen
Action = ActionModified
action, ok := events.DockerEventMap[msg.Action]
if !ok {
w.Debugf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"])
continue
}
eventCh <- Event{
ActorName: fmt.Sprintf("container %q", msg.Actor.Attributes["name"]),
Action: Action,
event := Event{
Type: events.EventTypeDocker,
ActorID: msg.Actor.ID,
ActorAttributes: msg.Actor.Attributes, // labels
ActorName: msg.Actor.Attributes["name"],
Action: action,
}
eventCh <- event
case err := <-cErrCh:
if err == nil {
continue
@@ -77,7 +110,7 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
default:
if D.IsErrConnectionFailed(err) {
time.Sleep(100 * time.Millisecond)
cEventCh, cErrCh = cl.Events(ctx, dwOptions)
cEventCh, cErrCh = w.client.Events(ctx, options)
}
}
}
@@ -88,8 +121,9 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
return eventCh, errCh
}
var dwOptions = events.ListOptions{Filters: filters.NewArgs(
filters.Arg("type", string(events.ContainerEventType)),
filters.Arg("event", string(events.ActionStart)),
filters.Arg("event", string(events.ActionDie)), // 'stop' already triggering 'die'
var optionsWatchAll = DockerListOptions{Filters: NewDockerFilter(
DockerFilterContainer,
DockerFilterStart,
DockerFilterStop,
DockerFilterDie,
)}

View File

@@ -1,26 +0,0 @@
package watcher
import "fmt"
type (
Event struct {
ActorName string
Action Action
}
Action string
)
const (
ActionModified Action = "MODIFIED"
ActionCreated Action = "CREATED"
ActionStarted Action = "STARTED"
ActionDeleted Action = "DELETED"
)
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionDeleted
}

View File

@@ -0,0 +1,49 @@
package events
import (
"fmt"
dockerEvents "github.com/docker/docker/api/types/events"
)
type (
Event struct {
Type EventType
ActorName string
ActorID string
ActorAttributes map[string]string
Action Action
}
Action string
EventType string
)
const (
ActionFileModified Action = "modified"
ActionFileCreated Action = "created"
ActionFileDeleted Action = "deleted"
ActionDockerStartUnpause Action = "start"
ActionDockerStopPause Action = "stop"
EventTypeDocker EventType = "docker"
EventTypeFile EventType = "file"
)
var DockerEventMap = map[dockerEvents.Action]Action{
dockerEvents.ActionCreate: ActionDockerStartUnpause,
dockerEvents.ActionStart: ActionDockerStartUnpause,
dockerEvents.ActionPause: ActionDockerStartUnpause,
dockerEvents.ActionDie: ActionDockerStopPause,
dockerEvents.ActionStop: ActionDockerStopPause,
dockerEvents.ActionUnPause: ActionDockerStopPause,
dockerEvents.ActionKill: ActionDockerStopPause,
}
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionFileDeleted
}

View File

@@ -20,7 +20,10 @@ func NewFileWatcher(filename string) Watcher {
}
func (f *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
if fwHelper == nil {
fwHelper = newFileWatcherHelper(common.ConfigBasePath)
}
return fwHelper.Add(ctx, f)
}
var fwHelper = newFileWatcherHelper(common.ConfigBasePath)
var fwHelper *fileWatcherHelper

View File

@@ -9,6 +9,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/watcher/events"
)
type fileWatcherHelper struct {
@@ -80,27 +81,30 @@ func (h *fileWatcherHelper) start() {
for {
select {
case event, ok := <-h.w.Events:
case fsEvent, ok := <-h.w.Events:
if !ok {
// closed manually?
fsLogger.Error("channel closed")
return
}
// retrieve the watcher
w, ok := h.m[path.Base(event.Name)]
w, ok := h.m[path.Base(fsEvent.Name)]
if !ok {
// watcher for this file does not exist
continue
}
msg := Event{ActorName: w.filename}
msg := Event{
Type: events.EventTypeFile,
ActorName: w.filename,
}
switch {
case event.Has(fsnotify.Create):
msg.Action = ActionCreated
case event.Has(fsnotify.Write):
msg.Action = ActionModified
case event.Has(fsnotify.Remove), event.Has(fsnotify.Rename):
msg.Action = ActionDeleted
case fsEvent.Has(fsnotify.Create):
msg.Action = events.ActionFileCreated
case fsEvent.Has(fsnotify.Write):
msg.Action = events.ActionFileModified
case fsEvent.Has(fsnotify.Remove), fsEvent.Has(fsnotify.Rename):
msg.Action = events.ActionFileDeleted
default: // ignore other events
continue
}

View File

@@ -4,8 +4,11 @@ import (
"context"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/watcher/events"
)
type Event = events.Event
type Watcher interface {
Events(ctx context.Context) (<-chan Event, <-chan E.NestedError)
}

View File

@@ -1 +0,0 @@
0.5.0-rc3