refactor and organize code

This commit is contained in:
yusing
2025-02-15 05:44:47 +08:00
parent 1af6dd9cf8
commit 18d258aaa2
169 changed files with 1020 additions and 755 deletions

View File

@@ -0,0 +1,173 @@
package accesslog
import (
"bytes"
"io"
"net/http"
"sync"
"time"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type (
AccessLogger struct {
task *task.Task
cfg *Config
io AccessLogIO
buf bytes.Buffer // buffer for non-flushed log
bufMu sync.RWMutex
bufPool sync.Pool // buffer pool for formatting a single log line
flushThreshold int
Formatter
}
AccessLogIO interface {
io.ReadWriteCloser
io.ReadWriteSeeker
io.ReaderAt
sync.Locker
Name() string // file name or path
Truncate(size int64) error
}
Formatter interface {
// Format writes a log line to line without a trailing newline
Format(line *bytes.Buffer, req *http.Request, res *http.Response)
SetGetTimeNow(getTimeNow func() time.Time)
}
)
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
l := &AccessLogger{
task: parent.Subtask("accesslog"),
cfg: cfg,
io: io,
}
if cfg.BufferSize < 1024 {
cfg.BufferSize = DefaultBufferSize
}
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
switch l.cfg.Format {
case FormatCommon:
l.Formatter = &fmt
case FormatCombined:
l.Formatter = &CombinedFormatter{fmt}
case FormatJSON:
l.Formatter = &JSONFormatter{fmt}
default: // should not happen, validation has done by validate tags
panic("invalid access log format")
}
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80%
l.buf.Grow(int(cfg.BufferSize))
l.bufPool.New = func() any {
return new(bytes.Buffer)
}
go l.start()
return l
}
func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool {
if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) ||
!l.cfg.Filters.Method.CheckKeep(req, res) ||
!l.cfg.Filters.Headers.CheckKeep(req, res) ||
!l.cfg.Filters.CIDR.CheckKeep(req, res) {
return false
}
return true
}
func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
if !l.checkKeep(req, res) {
return
}
line := l.bufPool.Get().(*bytes.Buffer)
l.Format(line, req, res)
line.WriteRune('\n')
l.bufMu.Lock()
l.buf.Write(line.Bytes())
line.Reset()
l.bufPool.Put(line)
l.bufMu.Unlock()
}
func (l *AccessLogger) LogError(req *http.Request, err error) {
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
}
func (l *AccessLogger) Config() *Config {
return l.cfg
}
func (l *AccessLogger) Rotate() error {
if l.cfg.Retention == nil {
return nil
}
l.io.Lock()
defer l.io.Unlock()
return l.cfg.Retention.rotateLogFile(l.io)
}
func (l *AccessLogger) Flush(force bool) {
if l.buf.Len() == 0 {
return
}
if force || l.buf.Len() >= l.flushThreshold {
l.bufMu.RLock()
l.write(l.buf.Bytes())
l.buf.Reset()
l.bufMu.RUnlock()
}
}
func (l *AccessLogger) handleErr(err error) {
gperr.LogError("failed to write access log", err)
}
func (l *AccessLogger) start() {
defer func() {
if l.buf.Len() > 0 { // flush last
l.write(l.buf.Bytes())
}
l.io.Close()
l.task.Finish(nil)
}()
// periodic flush + threshold flush
periodic := time.NewTicker(5 * time.Second)
threshold := time.NewTicker(time.Second)
defer periodic.Stop()
defer threshold.Stop()
for {
select {
case <-l.task.Context().Done():
return
case <-periodic.C:
l.Flush(true)
case <-threshold.C:
l.Flush(false)
}
}
}
func (l *AccessLogger) write(data []byte) {
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
_, err := l.io.Write(data)
l.io.Unlock()
if err != nil {
l.handleErr(err)
} else {
logging.Debug().Msg("access log flushed to " + l.io.Name())
}
}

View File

@@ -0,0 +1,127 @@
package accesslog_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
remote = "192.168.1.1"
host = "example.com"
uri = "/?bar=baz&foo=bar"
uriRedacted = "/?bar=" + RedactedValue + "&foo=" + RedactedValue
referer = "https://www.google.com/"
proto = "HTTP/1.1"
ua = "Go-http-client/1.1"
status = http.StatusOK
contentLength = 100
method = http.MethodGet
)
var (
testTask = task.RootTask("test", false)
testURL = Must(url.Parse("http://" + host + uri))
req = &http.Request{
RemoteAddr: remote,
Method: method,
Proto: proto,
Host: testURL.Host,
URL: testURL,
Header: http.Header{
"User-Agent": []string{ua},
"Referer": []string{referer},
"Cookie": []string{
"foo=bar",
"bar=baz",
},
},
}
resp = &http.Response{
StatusCode: status,
ContentLength: contentLength,
Header: http.Header{"Content-Type": []string{"text/plain"}},
}
)
func fmtLog(cfg *Config) (ts string, line string) {
var buf bytes.Buffer
t := time.Now()
logger := NewAccessLogger(testTask, nil, cfg)
logger.Formatter.SetGetTimeNow(func() time.Time {
return t
})
logger.Format(&buf, req, resp)
return t.Format(LogTimeFormat), buf.String()
}
func TestAccessLoggerCommon(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uri, proto, status, contentLength,
),
)
}
func TestAccessLoggerCombined(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCombined
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
host, remote, ts, method, uri, proto, status, contentLength, referer, ua,
),
)
}
func TestAccessLoggerRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
config.Fields.Query.Default = FieldModeRedact
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uriRedacted, proto, status, contentLength,
),
)
}
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
t.Helper()
config.Format = FormatJSON
var entry JSONLogEntry
_, log := fmtLog(config)
err := json.Unmarshal([]byte(log), &entry)
ExpectNoError(t, err)
return entry
}
func TestAccessLoggerJSON(t *testing.T) {
config := DefaultConfig()
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.IP, remote)
ExpectEqual(t, entry.Method, method)
ExpectEqual(t, entry.Scheme, "http")
ExpectEqual(t, entry.Host, testURL.Host)
ExpectEqual(t, entry.URI, testURL.RequestURI())
ExpectEqual(t, entry.Protocol, proto)
ExpectEqual(t, entry.Status, status)
ExpectEqual(t, entry.ContentType, "text/plain")
ExpectEqual(t, entry.Size, contentLength)
ExpectEqual(t, entry.Referer, referer)
ExpectEqual(t, entry.UserAgent, ua)
ExpectEqual(t, len(entry.Headers), 0)
ExpectEqual(t, len(entry.Cookies), 0)
}

View File

@@ -0,0 +1,57 @@
package accesslog
import "github.com/yusing/go-proxy/internal/utils"
type (
Format string
Filters struct {
StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"`
Method LogFilter[HTTPMethod] `json:"method"`
Host LogFilter[Host] `json:"host"`
Headers LogFilter[*HTTPHeader] `json:"headers"` // header exists or header == value
CIDR LogFilter[*CIDR] `json:"cidr"`
}
Fields struct {
Headers FieldConfig `json:"headers"`
Query FieldConfig `json:"query"`
Cookies FieldConfig `json:"cookies"`
}
Config struct {
BufferSize uint `json:"buffer_size" validate:"gte=1"`
Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"`
Fields Fields `json:"fields"`
Retention *Retention `json:"retention"`
}
)
var (
FormatCommon Format = "common"
FormatCombined Format = "combined"
FormatJSON Format = "json"
)
const DefaultBufferSize = 64 * 1024 // 64KB
func DefaultConfig() *Config {
return &Config{
BufferSize: DefaultBufferSize,
Format: FormatCombined,
Fields: Fields{
Headers: FieldConfig{
Default: FieldModeDrop,
},
Query: FieldConfig{
Default: FieldModeKeep,
},
Cookies: FieldConfig{
Default: FieldModeDrop,
},
},
}
}
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
}

View File

@@ -0,0 +1,53 @@
package accesslog_test
import (
"testing"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewConfig(t *testing.T) {
labels := map[string]string{
"proxy.buffer_size": "10",
"proxy.format": "combined",
"proxy.path": "/tmp/access.log",
"proxy.filters.status_codes.values": "200-299",
"proxy.filters.method.values": "GET, POST",
"proxy.filters.headers.values": "foo=bar, baz",
"proxy.filters.headers.negative": "true",
"proxy.filters.cidr.values": "192.168.10.0/24",
"proxy.fields.headers.default": "keep",
"proxy.fields.headers.config.foo": "redact",
"proxy.fields.query.default": "drop",
"proxy.fields.query.config.foo": "keep",
"proxy.fields.cookies.default": "redact",
"proxy.fields.cookies.config.foo": "keep",
}
parsed, err := docker.ParseLabels(labels)
ExpectNoError(t, err)
var config Config
err = utils.Deserialize(parsed, &config)
ExpectNoError(t, err)
ExpectEqual(t, config.BufferSize, 10)
ExpectEqual(t, config.Format, FormatCombined)
ExpectEqual(t, config.Path, "/tmp/access.log")
ExpectDeepEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, len(config.Filters.Method.Values), 2)
ExpectDeepEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
ExpectDeepEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectTrue(t, config.Filters.Headers.Negative)
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep)
ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop)
ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep)
ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact)
ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
}

View File

@@ -0,0 +1,103 @@
package accesslog
import (
"net/http"
"net/url"
)
type (
FieldConfig struct {
Default FieldMode `json:"default" validate:"oneof=keep drop redact"`
Config map[string]FieldMode `json:"config" validate:"dive,oneof=keep drop redact"`
}
FieldMode string
)
const (
FieldModeKeep FieldMode = "keep"
FieldModeDrop FieldMode = "drop"
FieldModeRedact FieldMode = "redact"
RedactedValue = "REDACTED"
)
func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V {
if len(cfg.Config) == 0 {
switch cfg.Default {
case FieldModeKeep:
return m
case FieldModeDrop:
return nil
case FieldModeRedact:
redacted := make(map[string]V)
for k := range m {
redacted[k] = redactedV
}
return redacted
}
}
if len(m) == 0 {
return m
}
newMap := make(map[string]V, len(m))
for k := range m {
var mode FieldMode
var ok bool
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = m[k]
case FieldModeRedact:
newMap[k] = redactedV
}
}
return newMap
}
func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn {
if len(s) == 0 ||
len(cfg.Config) == 0 && cfg.Default == FieldModeDrop {
return nil
}
newMap := make(map[string]VReturn, len(s))
for _, v := range s {
var mode FieldMode
var ok bool
k := getKey(v)
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = convert(v)
case FieldModeRedact:
newMap[k] = redact(v)
}
}
return newMap
}
func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header {
return processMap(cfg, headers, []string{RedactedValue})
}
func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values {
return processMap(cfg, q, []string{RedactedValue})
}
func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string {
return processSlice(cfg, cookies,
func(c *http.Cookie) string {
return c.Name
},
func(c *http.Cookie) string {
return c.Value
},
func(c *http.Cookie) string {
return RedactedValue
})
}

View File

@@ -0,0 +1,96 @@
package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
// Cookie header should be removed,
// stored in JSONLogEntry.Cookies instead.
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
entry := getJSONEntry(t, config)
for k, v := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], v)
}
}
config.Fields.Headers.Config = map[string]FieldMode{
"Referer": FieldModeRedact,
"User-Agent": FieldModeDrop,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Headers["User-Agent"], nil)
}
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeDrop
entry := getJSONEntry(t, config)
for k := range req.Header {
ExpectDeepEqual(t, entry.Headers[k], nil)
}
config.Fields.Headers.Config = map[string]FieldMode{
"Referer": FieldModeKeep,
"User-Agent": FieldModeRedact,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectDeepEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], []string{RedactedValue})
}
}
}
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeKeep
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
}
}
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
}
}
func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeDrop
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], nil)
ExpectDeepEqual(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Query["bar"], []string{RedactedValue})
}

View File

@@ -0,0 +1,69 @@
package accesslog
import (
"fmt"
"os"
"path"
"sync"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
type File struct {
*os.File
sync.Mutex
// os.File.Name() may not equal to key of `openedFiles`.
// Store it for later delete from `openedFiles`.
path string
refCount *utils.RefCount
}
var (
openedFiles = make(map[string]*File)
openedFilesMu sync.Mutex
)
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
openedFilesMu.Lock()
var file *File
path := path.Clean(cfg.Path)
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
file = opened
} else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err)
}
file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
}
openedFilesMu.Unlock()
return NewAccessLogger(parent, file, cfg), nil
}
func (f *File) Close() error {
f.refCount.Sub()
return nil
}
func (f *File) closeOnZero() {
defer logging.Debug().
Str("path", f.path).
Msg("access log closed")
<-f.refCount.Zero()
openedFilesMu.Lock()
delete(openedFiles, f.path)
openedFilesMu.Unlock()
f.File.Close()
}

View File

@@ -0,0 +1,95 @@
package accesslog
import (
"net/http"
"os"
"sync"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/task"
)
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
var wg sync.WaitGroup
cfg := DefaultConfig()
cfg.Path = "test.log"
parent := task.RootTask("test", false)
loggerCount := 10
accessLogIOs := make([]AccessLogIO, loggerCount)
// make test log file
file, err := os.Create(cfg.Path)
ExpectNoError(t, err)
file.Close()
t.Cleanup(func() {
ExpectNoError(t, os.Remove(cfg.Path))
})
for i := range loggerCount {
wg.Add(1)
go func(index int) {
defer wg.Done()
logger, err := NewFileAccessLogger(parent, cfg)
ExpectNoError(t, err)
accessLogIOs[index] = logger.io
}(i)
}
wg.Wait()
firstIO := accessLogIOs[0]
for _, io := range accessLogIOs {
ExpectEqual(t, io, firstIO)
}
}
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
var file MockFile
cfg := DefaultConfig()
cfg.BufferSize = 1024
parent := task.RootTask("test", false)
loggerCount := 5
logCountPerLogger := 10
loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount {
loggers[i] = NewAccessLogger(parent, &file, cfg)
}
var wg sync.WaitGroup
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp := &http.Response{StatusCode: http.StatusOK}
for _, logger := range loggers {
wg.Add(1)
go func(l *AccessLogger) {
defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger)
l.Flush(true)
}(logger)
}
wg.Wait()
expected := loggerCount * logCountPerLogger
actual := file.Count()
ExpectEqual(t, actual, expected)
}
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {
var wg sync.WaitGroup
wg.Add(n)
for range n {
go func() {
defer wg.Done()
logger.Log(req, resp)
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,99 @@
package accesslog
import (
"net"
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
LogFilter[T Filterable] struct {
Negative bool
Values []T
}
Filterable interface {
comparable
Fulfill(req *http.Request, res *http.Response) bool
}
HTTPMethod string
HTTPHeader struct {
Key, Value string
}
Host string
CIDR struct{ types.CIDR }
)
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
if len(f.Values) == 0 {
return !f.Negative
}
for _, check := range f.Values {
if check.Fulfill(req, res) {
return !f.Negative
}
}
return f.Negative
}
func (r *StatusCodeRange) Fulfill(req *http.Request, res *http.Response) bool {
return r.Includes(res.StatusCode)
}
func (method HTTPMethod) Fulfill(req *http.Request, res *http.Response) bool {
return req.Method == string(method)
}
// Parse implements strutils.Parser.
func (k *HTTPHeader) Parse(v string) error {
split := strutils.SplitRune(v, '=')
switch len(split) {
case 1:
split = append(split, "")
case 2:
default:
return ErrInvalidHTTPHeaderFilter.Subject(v)
}
k.Key = split[0]
k.Value = split[1]
return nil
}
func (k *HTTPHeader) Fulfill(req *http.Request, res *http.Response) bool {
wanted := k.Value
// non canonical key matching
got, ok := req.Header[k.Key]
if wanted == "" {
return ok
}
if !ok {
return false
}
for _, v := range got {
if strings.EqualFold(v, wanted) {
return true
}
}
return false
}
func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h)
}
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
ip = req.RemoteAddr
}
netIP := net.ParseIP(ip)
if netIP == nil {
return false
}
return cidr.Contains(netIP)
}

View File

@@ -0,0 +1,188 @@
package accesslog_test
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestStatusCodeFilter(t *testing.T) {
values := []*StatusCodeRange{
strutils.MustParse[*StatusCodeRange]("200-308"),
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{}
ExpectTrue(t, filter.CheckKeep(nil, nil))
// keep any 2xx 3xx (inclusive)
filter.Values = values
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(nil, nil))
// drop any 2xx 3xx (inclusive)
filter.Values = values
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
}
func TestMethodFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{}
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// keep get only
filter.Values = []HTTPMethod{http.MethodGet}
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// drop post only
filter.Values = []HTTPMethod{http.MethodPost}
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
})
}
func TestHeaderFilter(t *testing.T) {
fooBar := &http.Request{
Header: http.Header{
"Foo": []string{"bar"},
},
}
fooBaz := &http.Request{
Header: http.Header{
"Foo": []string{"baz"},
},
}
headerFoo := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo"),
}
ExpectEqual(t, headerFoo[0].Key, "Foo")
ExpectEqual(t, headerFoo[0].Value, "")
headerFooBar := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo=bar"),
}
ExpectEqual(t, headerFooBar[0].Key, "Foo")
ExpectEqual(t, headerFooBar[0].Value, "bar")
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{}
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
// keep any foo
filter.Values = headerFoo
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
// keep foo == bar
filter.Values = headerFooBar
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
// drop any foo
filter.Values = headerFoo
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
// drop foo == bar
filter.Values = headerFooBar
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
})
}
func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{
strutils.MustParse[*CIDR]("192.168.10.0/24"),
}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",
}
notInCIDR := &http.Request{
RemoteAddr: "192.168.11.1",
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*CIDR]{}
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*CIDR]{Negative: true}
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
})
}

View File

@@ -0,0 +1,144 @@
package accesslog
import (
"bytes"
"encoding/json"
"net"
"net/http"
"net/url"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/logging"
)
type (
CommonFormatter struct {
cfg *Fields
GetTimeNow func() time.Time // for testing purposes only
}
CombinedFormatter struct{ CommonFormatter }
JSONFormatter struct{ CommonFormatter }
JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
URI string `json:"uri"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
)
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
func scheme(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
}
func requestURI(u *url.URL, query url.Values) string {
uri := u.EscapedPath()
if len(query) > 0 {
uri += "?" + query.Encode()
}
return uri
}
func clientIP(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return req.RemoteAddr
}
// debug only.
func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) {
f.GetTimeNow = getTimeNow
}
func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
line.WriteString(req.Host)
line.WriteRune(' ')
line.WriteString(clientIP(req))
line.WriteString(" - - [")
line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
line.WriteString("] \"")
line.WriteString(req.Method)
line.WriteRune(' ')
line.WriteString(requestURI(req.URL, query))
line.WriteRune(' ')
line.WriteString(req.Proto)
line.WriteString("\" ")
line.WriteString(strconv.Itoa(res.StatusCode))
line.WriteRune(' ')
line.WriteString(strconv.FormatInt(res.ContentLength, 10))
}
func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
f.CommonFormatter.Format(line, req, res)
line.WriteString(" \"")
line.WriteString(req.Referer())
line.WriteString("\" \"")
line.WriteString(req.UserAgent())
line.WriteRune('"')
}
func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
entry := JSONLogEntry{
Time: f.GetTimeNow().Format(LogTimeFormat),
IP: clientIP(req),
Method: req.Method,
Scheme: scheme(req),
Host: req.Host,
URI: requestURI(req.URL, query),
Protocol: req.Proto,
Status: res.StatusCode,
ContentType: res.Header.Get("Content-Type"),
Size: res.ContentLength,
Referer: req.Referer(),
UserAgent: req.UserAgent(),
Query: query,
Headers: headers,
Cookies: cookies,
}
if res.StatusCode >= 400 {
entry.Error = res.Status
}
if entry.ContentType == "" {
// try to get content type from request
entry.ContentType = req.Header.Get("Content-Type")
}
marshaller := json.NewEncoder(line)
err := marshaller.Encode(entry)
if err != nil {
logging.Err(err).Msg("failed to marshal json log")
}
}

View File

@@ -0,0 +1,74 @@
package accesslog
import (
"bytes"
"io"
"sync"
)
type MockFile struct {
data []byte
position int64
sync.Mutex
}
func (m *MockFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case io.SeekStart:
m.position = offset
case io.SeekCurrent:
m.position += offset
case io.SeekEnd:
m.position = int64(len(m.data)) + offset
}
return m.position, nil
}
func (m *MockFile) Write(p []byte) (n int, err error) {
m.data = append(m.data, p...)
n = len(p)
m.position += int64(n)
return
}
func (m *MockFile) Name() string {
return "mock"
}
func (m *MockFile) Read(p []byte) (n int, err error) {
if m.position >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[m.position:])
m.position += int64(n)
return n, nil
}
func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
if off >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[off:])
m.position += int64(n)
return n, nil
}
func (m *MockFile) Close() error {
return nil
}
func (m *MockFile) Truncate(size int64) error {
m.data = m.data[:size]
m.position = size
return nil
}
func (m *MockFile) Count() int {
m.Lock()
defer m.Unlock()
return bytes.Count(m.data[:m.position], []byte("\n"))
}
func (m *MockFile) Len() int64 {
return m.position
}

View File

@@ -0,0 +1,198 @@
package accesslog
import (
"bufio"
"bytes"
"io"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
const chunkSizeMax int64 = 128 * 1024 // 128KB
var (
ErrInvalidSyntax = gperr.New("invalid syntax")
ErrZeroValue = gperr.New("zero value")
)
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
lastN := int(r.Last)
days := int(r.Days)
// Seek to end to get file size
size, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
// Initialize ring buffer for last N lines
lines := make([][]byte, 0, lastN|(days*1000))
pos := size
unprocessed := 0
var chunk [chunkSizeMax]byte
var lastLine []byte
var shouldStop func() bool
if days > 0 {
cutoff := time.Now().AddDate(0, 0, -days)
shouldStop = func() bool {
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
}
} else {
shouldStop = func() bool {
return len(lines) == lastN
}
}
// Read backwards until we have enough lines or reach start of file
for pos > 0 {
if pos > chunkSizeMax {
pos -= chunkSizeMax
} else {
pos = 0
}
// Seek to the current chunk
if _, err = file.Seek(pos, io.SeekStart); err != nil {
return err
}
var nRead int
// Read the chunk
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
return err
}
// last unprocessed bytes + read bytes
curChunk := chunk[:unprocessed+nRead]
unprocessed = len(curChunk)
// Split into lines
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
for !shouldStop() && scanner.Scan() {
lastLine = scanner.Bytes()
lines = append(lines, lastLine)
unprocessed -= len(lastLine)
}
if shouldStop() {
break
}
// move unprocessed bytes to the beginning for next iteration
copy(chunk[:], curChunk[unprocessed:])
}
if days > 0 {
// truncate to the end of the log within last N days
return file.Truncate(pos)
}
// write lines to buffer in reverse order
// since we read them backwards
var buf bytes.Buffer
for i := len(lines) - 1; i >= 0; i-- {
buf.Write(lines[i])
buf.WriteRune('\n')
}
return writeTruncate(file, &buf)
}
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
// Seek to beginning and truncate
if _, err := file.Seek(0, 0); err != nil {
return err
}
buffered := bufio.NewWriter(file)
// Write buffer back to file
nWritten, err := buffered.Write(buf.Bytes())
if err != nil {
return err
}
if err = buffered.Flush(); err != nil {
return err
}
// Truncate file
if err = file.Truncate(int64(nWritten)); err != nil {
return err
}
// check bytes written == buffer size
if nWritten != buf.Len() {
return io.ErrShortWrite
}
return
}
func parseLogTime(line []byte) (t time.Time) {
if len(line) == 0 {
return
}
var start, end int
const jsonStart = len(`{"time":"`)
const jsonEnd = jsonStart + len(LogTimeFormat)
if len(line) == '{' { // possibly json log
start = jsonStart
end = jsonEnd
} else { // possibly common or combined format
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
start = bytes.IndexRune(line, '[')
end = bytes.IndexRune(line[start+1:], ']')
if start == -1 || end == -1 || start >= end {
return
}
}
timeStr := line[start+1 : end]
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
return
}

View File

@@ -0,0 +1,81 @@
package accesslog_test
import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}
func TestRetentionCommonFormat(t *testing.T) {
var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
Format: FormatCommon,
BufferSize: 1024,
})
for range 10 {
logger.Log(req, resp)
}
logger.Flush(true)
// test.Finish(nil)
ExpectEqual(t, logger.Config().Retention, nil)
ExpectTrue(t, file.Len() > 0)
ExpectEqual(t, file.Count(), 10)
t.Run("keep last", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
ExpectEqual(t, logger.Config().Retention.Days, 0)
ExpectEqual(t, logger.Config().Retention.Last, 5)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 5)
})
_ = file.Truncate(0)
timeNow := time.Now()
for i := range 10 {
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
return timeNow.AddDate(0, 0, -i)
}
logger.Log(req, resp)
}
logger.Flush(true)
// FIXME: keep days does not work
t.Run("keep days", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
ExpectEqual(t, logger.Config().Retention.Days, 3)
ExpectEqual(t, logger.Config().Retention.Last, 0)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.Count(), 3)
})
}

View File

@@ -0,0 +1,52 @@
package accesslog
import (
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type StatusCodeRange struct {
Start int
End int
}
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
func (r *StatusCodeRange) Includes(code int) bool {
return r.Start <= code && code <= r.End
}
// Parse implements strutils.Parser.
func (r *StatusCodeRange) Parse(v string) error {
split := strutils.SplitRune(v, '-')
switch len(split) {
case 1:
start, err := strconv.Atoi(split[0])
if err != nil {
return gperr.Wrap(err)
}
r.Start = start
r.End = start
return nil
case 2:
start, errStart := strconv.Atoi(split[0])
end, errEnd := strconv.Atoi(split[1])
if err := gperr.Join(errStart, errEnd); err != nil {
return err
}
r.Start = start
r.End = end
return nil
default:
return ErrInvalidStatusCodeRange.Subject(v)
}
}
func (r *StatusCodeRange) String() string {
if r.Start == r.End {
return strconv.Itoa(r.Start)
}
return strconv.Itoa(r.Start) + "-" + strconv.Itoa(r.End)
}

View File

@@ -0,0 +1,46 @@
package gphttp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
)
func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil {
switch {
case errors.Is(err, http.ErrHandlerTimeout),
errors.Is(err, context.DeadlineExceeded):
logging.Err(err).Msg("timeout writing body")
default:
logging.Err(err).Msg("failed to write body")
}
}
}
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
w.Header().Set("Content-Type", "application/json")
var err error
switch data := data.(type) {
case string:
_, err = w.Write([]byte(fmt.Sprintf("%q", data)))
case []byte:
panic("use WriteBody instead")
default:
err = json.NewEncoder(w).Encode(data)
}
if err != nil {
LogError(r).Err(err).Msg("failed to encode json")
return false
}
return true
}

View File

@@ -0,0 +1,78 @@
package gphttp
import (
"mime"
"net/http"
)
type (
ContentType string
AcceptContentType []ContentType
)
func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type")
if ct == "" {
return ""
}
ct, _, err := mime.ParseMediaType(ct)
if err != nil {
return ""
}
return ContentType(ct)
}
func GetAccept(h http.Header) AcceptContentType {
var accepts []ContentType
for _, v := range h["Accept"] {
ct, _, err := mime.ParseMediaType(v)
if err != nil {
continue
}
accepts = append(accepts, ContentType(ct))
}
return accepts
}
func (ct ContentType) IsHTML() bool {
return ct == "text/html" || ct == "application/xhtml+xml"
}
func (ct ContentType) IsJSON() bool {
return ct == "application/json"
}
func (ct ContentType) IsPlainText() bool {
return ct == "text/plain"
}
func (act AcceptContentType) IsEmpty() bool {
return len(act) == 0
}
func (act AcceptContentType) AcceptHTML() bool {
for _, v := range act {
if v.IsHTML() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptJSON() bool {
for _, v := range act {
if v.IsJSON() || v == "*/*" {
return true
}
}
return false
}
func (act AcceptContentType) AcceptPlainText() bool {
for _, v := range act {
if v.IsPlainText() || v == "text/*" || v == "*/*" {
return true
}
}
return false
}

View File

@@ -0,0 +1,41 @@
package gphttp
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestContentTypes(t *testing.T) {
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
}
func TestAcceptContentTypes(t *testing.T) {
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
}

View File

@@ -0,0 +1,27 @@
package gphttp
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
httpClient = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 60 * time.Second, // this is different from DisableKeepAlives
}).DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
Get = httpClient.Get
Post = httpClient.Post
Head = httpClient.Head
)

View File

@@ -0,0 +1,95 @@
package gphttp
import (
"context"
"encoding/json"
"errors"
"net/http"
"syscall"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
// ServerError is for handling server errors.
//
// It logs the error and returns http.StatusInternalServerError to the client.
// Status code can be specified as an argument.
func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) {
switch {
case err == nil,
errors.Is(err, context.Canceled),
errors.Is(err, syscall.EPIPE),
errors.Is(err, syscall.ECONNRESET):
return
}
LogError(r).Msg(err.Error())
if httpheaders.IsWebsocket(r.Header) {
return
}
if len(code) == 0 {
code = []int{http.StatusInternalServerError}
}
http.Error(w, http.StatusText(code[0]), code[0])
}
// ClientError is for responding to client errors.
//
// It returns http.StatusBadRequest with reason to the client.
// Status code can be specified as an argument.
//
// For JSON marshallable errors (e.g. gperr.Error), it returns the error details as JSON.
// Otherwise, it returns the error details as plain text.
func ClientError(w http.ResponseWriter, err error, code ...int) {
if len(code) == 0 {
code = []int{http.StatusBadRequest}
}
if gperr.IsJSONMarshallable(err) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(err)
} else {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
http.Error(w, err.Error(), code[0])
}
// JSONError returns a JSON response of gperr.Error with the given status code.
func JSONError(w http.ResponseWriter, err gperr.Error, code int) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(err)
http.Error(w, err.Error(), code)
}
// BadRequest returns a Bad Request response with the given error message.
func BadRequest(w http.ResponseWriter, err string, code ...int) {
if len(code) == 0 {
code = []int{http.StatusBadRequest}
}
http.Error(w, err, code[0])
}
// Unauthorized returns an Unauthorized response with the given error message.
func Unauthorized(w http.ResponseWriter, err string) {
BadRequest(w, err, http.StatusUnauthorized)
}
// NotFound returns a Not Found response with the given error message.
func NotFound(w http.ResponseWriter, err string) {
BadRequest(w, err, http.StatusNotFound)
}
func ErrMissingKey(k string) error {
return gperr.New(k + " is required")
}
func ErrInvalidKey(k string) error {
return gperr.New(k + " is invalid")
}
func ErrAlreadyExists(k, v string) error {
return gperr.Errorf("%s %q already exists", k, v)
}
func ErrNotFound(k, v string) error {
return gperr.Errorf("%s %q not found", k, v)
}

View File

@@ -0,0 +1,86 @@
package gpwebsocket
import (
"net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
func warnNoMatchDomains() {
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
}
var warnNoMatchDomainOnce sync.Once
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
} else {
originPats = make([]string, len(allowedDomains))
for i, domain := range allowedDomains {
if domain[0] != '.' {
originPats[i] = "*." + domain
} else {
originPats[i] = "*" + domain
}
}
originPats = append(originPats, localAddresses...)
}
return websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: originPats,
})
}
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
return
}
//nolint:errcheck
defer conn.CloseNow()
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
return
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
return
case <-ticker.C:
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
return
}
}
}
}
// WriteText writes a text message to the websocket connection.
// It returns true if the message was written successfully, false otherwise.
// It logs an error if the message is not written successfully.
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
gperr.LogError("failed to write text message", err)
return false
}
return true
}

View File

@@ -0,0 +1,7 @@
package httpheaders
import "net/http"
func IsSSE(h http.Header) bool {
return h.Get("Content-Type") == "text/event-stream"
}

View File

@@ -0,0 +1,119 @@
package httpheaders
import (
"net/http"
"net/textproto"
"github.com/yusing/go-proxy/internal/utils/strutils"
"golang.org/x/net/http/httpguts"
)
const (
HeaderXForwardedMethod = "X-Forwarded-Method"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedHost = "X-Forwarded-Host"
HeaderXForwardedPort = "X-Forwarded-Port"
HeaderXForwardedURI = "X-Forwarded-Uri"
HeaderXRealIP = "X-Real-IP"
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
HeaderUpstreamName = "X-Godoxy-Upstream-Name"
HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-Godoxy-Upstream-Host"
HeaderUpstreamPort = "X-Godoxy-Upstream-Port"
HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect"
)
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func UpgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// RemoveHopByHopHeaders removes hop-by-hop headers.
func RemoveHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strutils.SplitComma(f) {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h)
if reqUpType != "" {
h.Set("Connection", "Upgrade")
h.Set("Upgrade", reqUpType)
} else {
h.Del("Connection")
}
}
func CopyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func FilterHeaders(h http.Header, allowed []string) http.Header {
if len(allowed) == 0 {
return h
}
filtered := make(http.Header)
for i, header := range allowed {
values := h.Values(header)
if len(values) == 0 {
continue
}
filtered[http.CanonicalHeaderKey(allowed[i])] = append([]string(nil), values...)
}
return filtered
}
func HeaderToMap(h http.Header) map[string]string {
result := make(map[string]string)
for k, v := range h {
if len(v) > 0 {
result[k] = v[0] // Take the first value
}
}
return result
}

View File

@@ -0,0 +1,21 @@
package httpheaders
import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View File

@@ -0,0 +1,91 @@
package loadbalancer
import (
"hash/fnv"
"net"
"net/http"
"sync"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
)
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool Servers
mu sync.Mutex
}
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err gperr.Error
impl.realIP, err = middleware.RealIP.New(lb.Options)
if err != nil {
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
}
return impl
}
func (impl *ipHash) OnAddServer(srv Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
return
}
if s == nil {
impl.pool[i] = srv
return
}
}
impl.pool = append(impl.pool, srv)
}
func (impl *ipHash) OnRemoveServer(srv Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
impl.pool[i] = nil
return
}
}
}
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
impl.serveHTTP(rw, r)
}
}
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
srv := impl.pool[idx]
if srv == nil || srv.Status().Bad() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
srv.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {
h := fnv.New32a()
h.Write([]byte(ip))
return h.Sum32()
}

View File

@@ -0,0 +1,53 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type leastConn struct {
*LoadBalancer
nConn F.Map[Server, *atomic.Int64]
}
func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{
LoadBalancer: lb,
nConn: F.NewMapOf[Server, *atomic.Int64](),
}
}
func (impl *leastConn) OnAddServer(srv Server) {
impl.nConn.Store(srv, new(atomic.Int64))
}
func (impl *leastConn) OnRemoveServer(srv Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {
minConn = nConn
srv = srvs[i]
}
}
minConn.Add(1)
srv.ServeHTTP(rw, r)
minConn.Add(-1)
}

View File

@@ -0,0 +1,314 @@
package loadbalancer
import (
"net/http"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
// TODO: stats of each server.
// TODO: support weighted mode.
type (
impl interface {
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv Server)
OnRemoveServer(srv Server)
}
LoadBalancer struct {
impl
*Config
task *task.Task
pool Pool
poolMu sync.Mutex
sumWeight Weight
startTime time.Time
l zerolog.Logger
}
)
const maxWeight Weight = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Config: new(Config),
pool: types.NewServerPool(),
l: logging.With().Str("name", cfg.Link).Logger(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
lb.startTime = time.Now()
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
parent.OnCancel("lb_remove_route", func() {
routes.DeleteHTTPRoute(lb.Link)
})
lb.task.OnFinished("cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v Server) {
lb.impl.OnRemoveServer(v)
})
}
})
return nil
}
// Task implements task.TaskStarter.
func (lb *LoadBalancer) Task() *task.Task {
return lb.task
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason)
}
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case types.ModeUnset, types.ModeRoundRobin:
lb.impl = lb.newRoundRobin()
case types.ModeLeastConn:
lb.impl = lb.newLeastConn()
case types.ModeIPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
lb.pool.RangeAll(func(_ string, srv Server) {
lb.impl.OnAddServer(srv)
})
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if cfg != nil {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.Link = cfg.Link
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
lb.Options = cfg.Options
}
}
if lb.impl == nil {
lb.updateImpl()
}
}
func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name()) {
old, _ := lb.pool.Load(srv.Name())
lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name(), srv)
lb.sumWeight += srv.Weight()
lb.rebalance()
lb.impl.OnAddServer(srv)
lb.l.Debug().
Str("action", "add").
Str("server", srv.Name()).
Msgf("%d servers available", lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name()) {
return
}
lb.pool.Delete(srv.Name())
lb.sumWeight -= srv.Weight()
lb.rebalance()
lb.impl.OnRemoveServer(srv)
lb.l.Debug().
Str("action", "remove").
Str("server", srv.Name()).
Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
return
}
}
func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == maxWeight {
return
}
poolSize := lb.pool.Size()
if poolSize == 0 {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / Weight(poolSize)
remainder := maxWeight % Weight(poolSize)
lb.pool.RangeAll(func(_ string, s Server) {
w := weightEach
lb.sumWeight += weightEach
if remainder > 0 {
w++
remainder--
}
s.SetWeight(w)
})
return
}
// scale evenly
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s Server) {
s.SetWeight(Weight(float64(s.Weight()) * scaleFactor))
lb.sumWeight += s.Weight()
})
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
lb.pool.Range(func(_ string, s Server) bool {
if delta == 0 {
return false
}
if delta > 0 {
s.SetWeight(s.Weight() + 1)
lb.sumWeight++
delta--
} else {
s.SetWeight(s.Weight() - 1)
lb.sumWeight--
delta++
}
return true
})
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srvs := lb.availServers()
if len(srvs) == 0 {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
// wake all servers
for _, srv := range srvs {
if err := srv.TryWake(); err != nil {
lb.l.Warn().Err(err).
Str("server", srv.Name()).
Msg("failed to wake server")
}
}
}
lb.impl.ServeHTTP(srvs, rw, r)
}
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) {
extra[v.Name()] = v
})
return (&monitor.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
Uptime: lb.Uptime(),
Extra: map[string]any{
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
return lb.Link
}
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if lb.pool.Size() == 0 {
return health.StatusUnknown
}
isHealthy := true
lb.pool.Range(func(_ string, srv Server) bool {
if srv.Status().Bad() {
isHealthy = false
return false
}
return true
})
if !isHealthy {
return health.StatusUnhealthy
}
return health.StatusHealthy
}
// Uptime implements health.HealthMonitor.
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// Latency implements health.HealthMonitor.
func (lb *LoadBalancer) Latency() time.Duration {
var sum time.Duration
lb.pool.RangeAll(func(_ string, srv Server) {
sum += srv.Latency()
})
return sum
}
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() []Server {
avail := make([]Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv Server) {
if srv.Status().Good() {
avail = append(avail, srv)
}
})
return avail
}

View File

@@ -0,0 +1,44 @@
package loadbalancer
import (
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(new(types.Config))
for range 10 {
lb.AddServer(types.TestNewServer(0))
}
lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(new(types.Config))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(new(types.Config))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .4))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
}

View File

@@ -0,0 +1,22 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
)
type roundRobin struct {
index atomic.Uint32
}
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv Server) {}
func (lb *roundRobin) OnRemoveServer(srv Server) {}
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) % uint32(len(srvs))
srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}
}

View File

@@ -0,0 +1,14 @@
package loadbalancer
import (
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
)
type (
Server = types.Server
Servers = []types.Server
Pool = types.Pool
Weight = types.Weight
Config = types.Config
Mode = types.Mode
)

View File

@@ -0,0 +1,8 @@
package types
type Config struct {
Link string `json:"link"`
Mode Mode `json:"mode"`
Weight Weight `json:"weight"`
Options map[string]any `json:"options,omitempty"`
}

View File

@@ -0,0 +1,32 @@
package types
import (
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Mode string
const (
ModeUnset Mode = ""
ModeRoundRobin Mode = "roundrobin"
ModeLeastConn Mode = "leastconn"
ModeIPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch strutils.ToLowerNoSnake(string(*mode)) {
case "":
return true
case string(ModeRoundRobin):
*mode = ModeRoundRobin
return true
case string(ModeLeastConn):
*mode = ModeLeastConn
return true
case string(ModeIPHash):
*mode = ModeIPHash
return true
}
*mode = ModeRoundRobin
return false
}

View File

@@ -0,0 +1,86 @@
package types
import (
"net/http"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
net "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
server struct {
_ U.NoCopy
name string
url *net.URL
weight Weight
http.Handler `json:"-"`
health.HealthMonitor
}
Server interface {
http.Handler
health.HealthMonitor
Name() string
URL() *net.URL
Weight() Weight
SetWeight(weight Weight)
TryWake() error
}
Pool = F.Map[string, Server]
)
var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &server{
name: name,
url: url,
weight: weight,
Handler: handler,
HealthMonitor: healthMon,
}
return srv
}
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{
weight: Weight(weight),
}
return srv
}
func (srv *server) Name() string {
return srv.name
}
func (srv *server) URL() *net.URL {
return srv.url
}
func (srv *server) Weight() Weight {
return srv.weight
}
func (srv *server) SetWeight(weight Weight) {
srv.weight = weight
}
func (srv *server) String() string {
return srv.name
}
func (srv *server) TryWake() error {
waker, ok := srv.Handler.(idlewatcher.Waker)
if ok {
if err := waker.Wake(); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,3 @@
package types
type Weight uint16

View File

@@ -0,0 +1,20 @@
package gphttp
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }
func LogDebug(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.DebugLevel) }

View File

@@ -0,0 +1,20 @@
package gphttp
import "net/http"
func IsMethodValid(method string) bool {
switch method {
case http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodConnect,
http.MethodOptions,
http.MethodTrace:
return true
default:
return false
}
}

View File

@@ -0,0 +1,82 @@
package middleware
import (
"net"
"net/http"
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
cidrWhitelist struct {
CIDRWhitelistOpts
Tracer
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Message string
}
)
var (
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
}
)
func init() {
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
statusCode := fl.Field().Int()
return gphttp.IsStatusCodeValid(int(statusCode))
})
}
// setup implements MiddlewareWithSetup.
func (wl *cidrWhitelist) setup() {
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
wl.cachedAddr = F.NewMapOf[string, bool]()
}
// before implements RequestModifier.
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
return wl.checkIP(w, r)
}
// checkIP checks if the IP address is allowed.
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
}
}
if !allow {
http.Error(w, wl.Message, wl.StatusCode)
return false
}
return true
}

View File

@@ -0,0 +1,91 @@
package middleware
import (
_ "embed"
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/cidr_whitelist_test.yml
var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
const testMessage = "test-message"
t.Run("valid", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
})
ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status": 403,
})
ExpectNoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status_code": 403,
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/123"},
"message": testMessage,
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestCIDRWhitelist(t *testing.T) {
errs := gperr.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
panic("bug occurred")
}
t.Run("deny", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
}
})
t.Run("accept", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})
}

View File

@@ -0,0 +1,127 @@
package middleware
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type cloudflareRealIP struct {
realIP realIP
Recursive bool
}
const (
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
cfCIDRsUpdateInterval = time.Hour
cfCIDRsUpdateRetryInterval = 3 * time.Second
)
var (
cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex
)
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
// setup implements MiddlewareWithSetup.
func (cri *cloudflareRealIP) setup() {
cri.realIP.RealIPOpts = RealIPOpts{
Header: "Cf-Connecting-Ip",
Recursive: cri.Recursive,
}
}
// before implements RequestModifier.
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.realIP.From = cidrs
}
return cri.realIP.before(w, r)
}
func (cri *cloudflareRealIP) enableTrace() {
cri.realIP.enableTrace()
}
func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
cfCIDRsMu.Lock()
defer cfCIDRsMu.Unlock()
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
if common.IsTest {
cfCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
} else {
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
)
if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
if len(cfCIDRs) == 0 {
logging.Warn().Msg("cloudflare CIDR range is empty")
}
}
cfCIDRsLastUpdate = time.Now()
logging.Info().Msg("cloudflare CIDR range updated")
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
for _, line := range strutils.SplitLine(string(body)) {
if line == "" {
continue
}
_, cidr, err := net.ParseCIDR(line)
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
}
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
}
return nil
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"bytes"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
)
type customErrorPage struct{}
var CustomErrorPage = NewMiddleware[customErrorPage]()
const StaticFilePathPrefix = "/$gperrorpage/"
// before implements RequestModifier.
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return !ServeStaticErrorPageFile(w, r)
}
// modifyResponse implements ResponseModifier.
func (customErrorPage) modifyResponse(resp *http.Response) error {
// only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
return nil
}
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
path := r.URL.Path
if path != "" && path[0] != '/' {
path = "/" + path
}
if strings.HasPrefix(path, StaticFilePathPrefix) {
filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
logging.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
logging.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
}
return false
}

View File

@@ -0,0 +1,96 @@
package errorpage
import (
"fmt"
"os"
"path"
"sync"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var (
setupOnce sync.Once
dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]()
)
func setup() {
t := task.RootTask("error_page", false)
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
loadContent()
go watchDir()
}
func GetStaticFile(filename string) ([]byte, bool) {
setupOnce.Do(setup)
return fileContentMap.Load(filename)
}
// try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 {
return fileContentMap.Load("404.html")
}
return
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
if fileContentMap.Has(file) {
continue
}
content, err := os.ReadFile(file)
if err != nil {
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
func watchDir() {
eventCh, errCh := dirWatcher.Events(task.RootContext())
for {
select {
case <-task.RootContextCanceled():
return
case event, ok := <-eventCh:
if !ok {
return
}
filename := event.ActorName
switch event.Action {
case events.ActionFileWritten:
fileContentMap.Delete(filename)
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
logging.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
logging.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
case err := <-errCh:
gperr.LogError("error watching error page directory", err)
}
}
}

View File

@@ -0,0 +1,44 @@
package metricslogger
import (
"net"
"net/http"
"github.com/yusing/go-proxy/internal/metrics"
)
type MetricsLogger struct {
ServiceName string `json:"service_name"`
}
func NewMetricsLogger(serviceName string) *MetricsLogger {
return &MetricsLogger{serviceName}
}
func (m *MetricsLogger) GetHandler(next http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
m.ServeHTTP(rw, req, next.ServeHTTP)
}
}
func (m *MetricsLogger) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
visitorIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
visitorIP = req.RemoteAddr
}
// req.RemoteAddr had been modified by middleware (if any)
lbls := &metrics.HTTPRouteMetricLabels{
Service: m.ServiceName,
Method: req.Method,
Host: req.Host,
Visitor: visitorIP,
Path: req.URL.Path,
}
next.ServeHTTP(newHTTPMetricLogger(rw, lbls), req)
}
func (m *MetricsLogger) ResetMetrics() {
metrics.GetRouteMetrics().UnregisterService(m.ServiceName)
}

View File

@@ -0,0 +1,47 @@
package metricslogger
import (
"net/http"
"time"
"github.com/yusing/go-proxy/internal/metrics"
)
type httpMetricLogger struct {
http.ResponseWriter
timestamp time.Time
labels *metrics.HTTPRouteMetricLabels
}
// WriteHeader implements http.ResponseWriter.
func (l *httpMetricLogger) WriteHeader(status int) {
l.ResponseWriter.WriteHeader(status)
duration := time.Since(l.timestamp)
go func() {
m := metrics.GetRouteMetrics()
m.HTTPReqTotal.Inc()
m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds()))
// ignore 1xx
switch {
case status >= 500:
m.HTTP5xx.With(l.labels).Inc()
case status >= 400:
m.HTTP4xx.With(l.labels).Inc()
case status >= 200:
m.HTTP2xx3xx.With(l.labels).Inc()
}
}()
}
func (l *httpMetricLogger) Unwrap() http.ResponseWriter {
return l.ResponseWriter
}
func newHTTPMetricLogger(w http.ResponseWriter, labels *metrics.HTTPRouteMetricLabels) *httpMetricLogger {
return &httpMetricLogger{
ResponseWriter: w,
timestamp: time.Now(),
labels: labels,
}
}

View File

@@ -0,0 +1,237 @@
package middleware
import (
"encoding/json"
"net/http"
"reflect"
"sort"
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = gperr.Error
ReverseProxy = reverseproxy.ReverseProxy
ProxyRequest = reverseproxy.ProxyRequest
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
name string
construct ImplNewFunc
impl any
// priority is only applied for ReverseProxy.
//
// Middleware compose follows the order of the slice
//
// Default is 10, 0 is the highest
priority int
}
ByPriority []*Middleware
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareFinalizerWithError interface {
finalize() error
}
MiddlewareWithTracer interface {
enableTrace()
getTracer() *Tracer
}
)
const DefaultPriority = 10
func (m ByPriority) Len() int { return len(m) }
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
func NewMiddleware[ImplType any]() *Middleware {
// type check
t := any(new(ImplType))
switch t.(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
_, hasFinializer := t.(MiddlewareFinalizer)
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
if hasFinializer && hasFinializerWithError {
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
logging.Debug().Msgf("middleware %s enabled trace", m.name)
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.getTracer()
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.SetParent(parent.getTracer())
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
if len(optsRaw) == 0 {
return nil
}
priority, ok := optsRaw["priority"].(int)
if ok {
m.priority = priority
// remove priority for deserialization, restore later
delete(optsRaw, "priority")
defer func() {
optsRaw["priority"] = priority
}()
} else {
m.priority = DefaultPriority
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() error {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
return nil
}
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
return finalizer.finalize()
}
return nil
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
if m.construct == nil { // likely a middleware from compose
if len(optsRaw) != 0 {
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
if err := mid.finalize(); err != nil {
return nil, gperr.Wrap(err)
}
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
"priority": m.priority,
}, "", " ")
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return exec.modifyResponse(resp)
})
}
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
var middlewares []*Middleware
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rp, middlewares)
return
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
sort.Sort(ByPriority(middlewares))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
}
}
}
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mr.modifyResponse
}
}
}

View File

@@ -0,0 +1,107 @@
package middleware
import (
"fmt"
"os"
"path"
"sort"
"github.com/yusing/go-proxy/internal/gperr"
"gopkg.in/yaml.v3"
)
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath)
if err != nil {
eb.Add(err)
return nil
}
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
}
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
eb.Add(err)
return nil
}
middlewares := make(map[string]*Middleware)
for name, defs := range rawMap {
chain, err := BuildMiddlewareFromChainRaw(name, defs)
if err != nil {
eb.Add(err.Subject(source))
} else {
middlewares[name+"@file"] = chain
}
}
return middlewares
}
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := gperr.NewBuilder("middlewares compile error")
invalidOpts := gperr.NewBuilder("options compile error")
for name, opts := range middlewaresMap {
m, err := Get(name)
if err != nil {
errs.Add(err)
continue
}
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
}
middlewares = append(middlewares, m)
}
if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
sort.Sort(ByPriority(middlewares))
return middlewares, errs.Error()
}
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
compiled, err := compileMiddlewares(middlewaresMap)
if err != nil {
return nil, err
}
return NewMiddlewareChain(name, compiled), nil
}
// TODO: check conflict or duplicates.
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
chainErr := gperr.NewBuilder("")
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"] == "" {
chainErr.Add(ErrMissingMiddlewareUse.Subjectf("%s[%d]", name, i))
continue
}
baseName := def["use"].(string)
base, err := Get(baseName)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
delete(def, "use")
m, err := base.New(def)
if err != nil {
chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {
return nil, chainErr.Error()
}
return NewMiddlewareChain(name, chain), nil
}

View File

@@ -0,0 +1,22 @@
package middleware
import (
_ "embed"
"encoding/json"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
errs := gperr.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data))
// TODO: test
}

View File

@@ -0,0 +1,61 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
)
type middlewareChain struct {
befores []RequestModifier
modResps []ResponseModifier
}
// TODO: check conflict or duplicates.
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
m := &Middleware{name: name, impl: chainMid}
for _, comp := range chain {
if before, ok := comp.impl.(RequestModifier); ok {
chainMid.befores = append(chainMid.befores, before)
}
if mr, ok := comp.impl.(ResponseModifier); ok {
chainMid.modResps = append(chainMid.modResps, mr)
}
comp.setParent(m)
}
if common.IsDebug {
for _, child := range chain {
child.enableTrace()
}
m.enableTrace()
}
return m
}
// before implements RequestModifier.
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
for _, b := range m.befores {
if proceedNext = b.before(w, r); !proceedNext {
return false
}
}
return true
}
// modifyResponse implements ResponseModifier.
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
if len(m.modResps) == 0 {
return nil
}
errs := gperr.NewBuilder("modify response errors")
for i, mr := range m.modResps {
if err := mr.modifyResponse(resp); err != nil {
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
}
}
return errs.Error()
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"net/http"
"strconv"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
type testPriority struct {
Value int `json:"value"`
}
var test = NewMiddleware[testPriority]()
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
return true
}
func TestMiddlewarePriority(t *testing.T) {
priorities := []int{4, 7, 9, 0}
chain := make([]*Middleware, len(priorities))
for i, p := range priorities {
mid, err := test.New(OptionsRaw{
"priority": p,
"value": i,
})
ExpectNoError(t, err)
chain[i] = mid
}
res, err := newMiddlewaresTest(chain, nil)
ExpectNoError(t, err)
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
}

View File

@@ -0,0 +1,104 @@
package middleware
import (
"path"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
var allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP,
"oidc": OIDC,
"request": ModifyRequest,
"modifyrequest": ModifyRequest,
"response": ModifyResponse,
"modifyresponse": ModifyResponse,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
}
var (
ErrUnknownMiddleware = gperr.New("unknown middleware")
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
}
func All() map[string]*Middleware {
return allMiddlewares
}
func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok {
errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue
}
allMiddlewares[name] = m
logging.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
}
}
// build again to resolve cross references
for _, defFile := range middlewareDefs {
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws {
name = strutils.ToLowerNoSnake(name)
if _, ok := allMiddlewares[name]; ok {
// already loaded above
continue
}
allMiddlewares[name] = m
logging.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
}
}
if errs.HasError() {
gperr.LogError(errs.About(), errs.Error())
}
}

View File

@@ -0,0 +1,91 @@
package middleware
import (
"net/http"
"path/filepath"
"strings"
)
type (
modifyRequest struct {
ModifyRequestOpts
Tracer
}
// order: add_prefix -> set_headers -> add_headers -> hide_headers
ModifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
AddPrefix string
needVarSubstitution bool
}
)
var ModifyRequest = NewMiddleware[modifyRequest]()
// finalize implements MiddlewareFinalizer.
func (mr *ModifyRequestOpts) finalize() {
mr.checkVarSubstitution()
}
// before implements RequestModifier.
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
mr.AddTraceRequest("before modify request", r)
mr.addPrefix(r, nil, r.URL.Path)
mr.modifyHeaders(r, nil, r.Header)
mr.AddTraceRequest("after modify request", r)
return true
}
func (mr *ModifyRequestOpts) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m {
if strings.ContainsRune(v, '$') {
mr.needVarSubstitution = true
return
}
}
}
}
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
if !mr.needVarSubstitution {
for k, v := range mr.SetHeaders {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = v
}()
}
headers[k] = []string{v}
}
for k, v := range mr.AddHeaders {
headers[k] = append(headers[k], v)
}
} else {
for k, v := range mr.SetHeaders {
if req != nil && strings.EqualFold(k, "host") {
defer func() {
req.Host = varReplace(req, resp, v)
}()
}
headers[k] = []string{varReplace(req, resp, v)}
}
for k, v := range mr.AddHeaders {
headers[k] = append(headers[k], varReplace(req, resp, v))
}
}
for _, k := range mr.HideHeaders {
delete(headers, k)
}
}
func (mr *modifyRequest) addPrefix(r *http.Request, _ *http.Response, path string) {
if len(mr.AddPrefix) == 0 {
return
}
r.URL.Path = filepath.Join(mr.AddPrefix, path)
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"bytes"
"net"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"User-Agent": "go-proxy/v0.5.0",
"Host": VarUpstreamAddr,
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Req-Content-Type": VarRequestContentType,
"X-Test-Req-Content-Length": VarRequestContentLen,
"X-Test-Remote-Host": VarRemoteHost,
"X-Test-Remote-Port": VarRemotePort,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
t.Run("add_prefix", func(t *testing.T) {
tests := []struct {
name string
path string
expectedPath string
upstreamURL string
addPrefix string
}{
{
name: "no prefix",
path: "/foo",
expectedPath: "/foo",
upstreamURL: "http://test.example.com",
},
{
name: "slash only",
path: "/",
expectedPath: "/",
upstreamURL: "http://test.example.com",
addPrefix: "/", // should not change anything
},
{
name: "some prefix",
path: "/test",
expectedPath: "/foo/test",
upstreamURL: "http://test.example.com",
addPrefix: "/foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app" + tt.path)
upstreamURL := types.MustParseURL(tt.upstreamURL)
opts["add_prefix"] = tt.addPrefix
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
})
}
})
}

View File

@@ -0,0 +1,20 @@
package middleware
import (
"net/http"
)
type modifyResponse struct {
ModifyRequestOpts
Tracer
}
var ModifyResponse = NewMiddleware[modifyResponse]()
// modifyResponse implements ResponseModifier.
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
mr.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.AddTraceResponse("after modify response", resp)
return nil
}

View File

@@ -0,0 +1,108 @@
package middleware
import (
"bytes"
"net"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"X-Test-Resp-Status": VarRespStatusCode,
"X-Test-Resp-Content-Type": VarRespContentType,
"X-Test-Resp-Content-Length": VarRespContentLen,
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Req-Content-Type": VarRequestContentType,
"X-Test-Req-Content-Length": VarRequestContentLen,
"X-Test-Remote-Host": VarRemoteHost,
"X-Test-Remote-Port": VarRemotePort,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("response_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
respHeaders: http.Header{
"Content-Type": []string{"application/json"},
},
respBody: bytes.Repeat([]byte("a"), 50),
respStatus: http.StatusOK,
})
ExpectNoError(t, err)
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
}

View File

@@ -0,0 +1,91 @@
package middleware
import (
"net/http"
"sync"
"sync/atomic"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/gperr"
)
type oidcMiddleware struct {
AllowedUsers []string `json:"allowed_users"`
AllowedGroups []string `json:"allowed_groups"`
auth auth.Provider
authMux *http.ServeMux
isInitialized int32
initMu sync.Mutex
}
var OIDC = NewMiddleware[oidcMiddleware]()
func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() {
return gperr.New("OIDC not enabled but OIDC middleware is used")
}
return nil
}
func (amw *oidcMiddleware) init() error {
if atomic.LoadInt32(&amw.isInitialized) == 1 {
return nil
}
return amw.initSlow()
}
func (amw *oidcMiddleware) initSlow() error {
amw.initMu.Lock()
if amw.isInitialized == 1 {
amw.initMu.Unlock()
return nil
}
defer func() {
amw.isInitialized = 1
amw.initMu.Unlock()
}()
authProvider, err := auth.NewOIDCProviderFromEnv()
if err != nil {
return err
}
authProvider.SetIsMiddleware(true)
if len(amw.AllowedUsers) > 0 {
authProvider.SetAllowedUsers(amw.AllowedUsers)
}
if len(amw.AllowedGroups) > 0 {
authProvider.SetAllowedGroups(amw.AllowedGroups)
}
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
amw.auth = authProvider
return nil
}
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if err := amw.init(); err != nil {
// no need to log here, main OIDC may already failed and logged
http.Error(w, err.Error(), http.StatusInternalServerError)
return false
}
if err := amw.auth.CheckToken(r); err != nil {
amw.authMux.ServeHTTP(w, r)
return false
}
if r.URL.Path == auth.OIDCLogoutPath {
amw.auth.LogoutCallbackHandler(w, r)
return false
}
return true
}

View File

@@ -0,0 +1,73 @@
package middleware
import (
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
RateLimiterOpts
Tracer
requestMap requestMap
mu sync.Mutex
}
RateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration `validate:"min=1s"`
}
)
var (
RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second,
}
)
// setup implements MiddlewareWithSetup.
func (rl *rateLimiter) setup() {
rl.RateLimiterOpts = rateLimiterOptsDefault
rl.requestMap = make(requestMap, 0)
}
// before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return false
}
rl.mu.Lock()
limiter, ok := rl.requestMap[host]
if !ok {
limiter = rl.newLimiter()
rl.requestMap[host] = limiter
}
rl.mu.Unlock()
if limiter.Allow() {
return true
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
}

View File

@@ -0,0 +1,27 @@
package middleware
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRateLimit(t *testing.T) {
opts := OptionsRaw{
"average": "10",
"burst": "10",
"period": "1s",
}
rl, err := RateLimiter.New(opts)
ExpectNoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
}

View File

@@ -0,0 +1,116 @@
package middleware
import (
"net"
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type (
realIP struct {
RealIPOpts
Tracer
}
RealIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"required,min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
the last address sent in the request header field defined by the Header field.
If recursive search is enabled,
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool
}
)
var (
RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
}
)
// setup implements MiddlewareWithSetup.
func (ri *realIP) setup() {
ri.RealIPOpts = realIPOptsDefault
}
// before implements RequestModifier.
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
ri.setRealIP(r)
return true
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From {
if CIDR.Contains(ip) {
return true
}
}
// not in any CIDR
return false
}
func (ri *realIP) setRealIP(req *http.Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIPStr = req.RemoteAddr
}
clientIP := net.ParseIP(clientIPStr)
isTrusted := false
for _, CIDR := range ri.From {
if CIDR.Contains(clientIP) {
isTrusted = true
break
}
}
if !isTrusted {
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
realIPs := req.Header.Values(ri.Header)
lastNonTrustedIP := ""
if len(realIPs) == 0 {
// try non-canonical key
realIPs = req.Header[ri.Header]
}
if len(realIPs) == 0 {
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
if !ri.Recursive {
lastNonTrustedIP = realIPs[len(realIPs)-1]
} else {
for _, r := range realIPs {
if !ri.isInCIDRList(net.ParseIP(r)) {
lastNonTrustedIP = r
}
}
}
if lastNonTrustedIP == "" {
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View File

@@ -0,0 +1,77 @@
package middleware
import (
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": httpheaders.HeaderXRealIP,
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
},
"recursive": true,
}
optExpected := &RealIPOpts{
Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),
},
{
IP: net.ParseIP("192.168.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
{
IP: net.ParseIP("172.16.0.0"),
Mask: net.IPv4Mask(255, 240, 0, 0),
},
},
Recursive: true,
}
ri, err := RealIP.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
}
}
func TestSetRealIP(t *testing.T) {
const (
testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{
"header": testHeader,
"from": []string{"0.0.0.0/0"},
}
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := RealIP.New(opts)
ExpectNoError(t, err)
mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err)
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
}

View File

@@ -0,0 +1,29 @@
package middleware
import (
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
)
type redirectHTTP struct{}
var RedirectHTTP = NewMiddleware[redirectHTTP]()
// before implements RequestModifier.
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if r.TLS != nil {
return true
}
r.URL.Scheme = "https"
host := r.Host
if i := strings.Index(host, ":"); i != -1 {
host = host[:i] // strip port number if present
}
r.URL.Host = host + ":" + common.ProxyHTTPSPort
logging.Debug().Str("url", r.URL.String()).Msg("redirect to https")
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return true
}

View File

@@ -0,0 +1,27 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("http://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("https://example.com"),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
)
// internal use only.
type setUpstreamHeaders struct {
Name, Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"name": rp.TargetName,
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true
}

View File

@@ -0,0 +1,23 @@
deny:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.1.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
accept:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.0.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
- 127.0.0.1

View File

@@ -0,0 +1,41 @@
theGreatPretender:
- use: HideXForwarded
- use: ModifyRequest
setHeaders:
X-Real-IP: 6.6.6.6
- use: ModifyResponse
hideHeaders:
- X-Test3
- X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
header: X-Real-IP
from:
- "127.0.0.0/8"
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP:
- use: ModifyRequest
setHeaders:
CF-Connecting-IP: 127.0.0.1
- use: CloudflareRealIP

View File

@@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View File

@@ -0,0 +1,176 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
func init() {
if !common.IsTest {
return
}
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestRecorder struct {
args *testArgs
parent http.RoundTripper
headers http.Header
remoteAddr string
}
func newRequestRecorder(args *testArgs) *requestRecorder {
return &requestRecorder{args: args}
}
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
rt.headers = req.Header
rt.remoteAddr = req.RemoteAddr
if rt.parent != nil {
resp, err = rt.parent.RoundTrip(req)
} else {
resp = &http.Response{
Status: http.StatusText(rt.args.respStatus),
StatusCode: rt.args.respStatus,
Header: testHeaders,
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
ContentLength: int64(len(rt.args.respBody)),
Request: req,
TLS: req.TLS,
}
}
if err == nil {
for k, v := range rt.args.respHeaders {
resp.Header[k] = v
}
}
return resp, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
RemoteAddr string
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
upstreamURL *types.URL
realRoundTrip bool
reqURL *types.URL
reqMethod string
headers http.Header
body []byte
respHeaders http.Header
respBody []byte
respStatus int
}
func (args *testArgs) setDefaults() {
if args.reqURL == nil {
args.reqURL = Must(types.ParseURL("https://example.com"))
}
if args.reqMethod == "" {
args.reqMethod = http.MethodGet
}
if args.upstreamURL == nil {
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.respHeaders == nil {
args.respHeaders = http.Header{}
}
if args.respBody == nil {
args.respBody = []byte("OK")
}
if args.respStatus == 0 {
args.respStatus = http.StatusOK
}
}
func (args *testArgs) bodyReader() io.Reader {
if args.body != nil {
return bytes.NewReader(args.body)
}
return nil
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
return newMiddlewaresTest([]*Middleware{mid}, args)
}
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
for k, v := range args.headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
rr := newRequestRecorder(args)
if args.realRoundTrip {
rr.parent = http.DefaultTransport
}
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
patchReverseProxy(rp, middlewares)
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, gperr.Wrap(err)
}
return &TestResult{
RequestHeaders: rr.headers,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
RemoteAddr: rr.remoteAddr,
Data: data,
}, nil
}

View File

@@ -0,0 +1,87 @@
package middleware
import (
"net/http"
"sync"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
Traces []*Trace
)
var (
traces = make(Traces, 0)
tracesMu sync.Mutex
)
const MaxTraceNum = 100
func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *http.Request) *Trace {
if tr == nil {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr
}
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
if tr == nil {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}
func (tr *Trace) With(what string, additional any) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional[what] = additional
return tr
}
func (tr *Trace) WithError(err error) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional["error"] = err.Error()
return tr
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()
if len(traces) > MaxTraceNum {
traces = traces[1:]
}
traces = append(traces, t)
return t
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Tracer struct {
name string
enabled bool
}
func _() {
var _ MiddlewareWithTracer = &Tracer{}
}
func (t *Tracer) enableTrace() {
t.enabled = true
}
func (t *Tracer) getTracer() *Tracer {
return t
}
func (t *Tracer) SetParent(parent *Tracer) {
if parent == nil {
return
}
t.name = parent.name + "." + t.name
}
func (t *Tracer) addTrace(msg string) *Trace {
return addTrace(&Trace{
Time: strutils.FormatTime(time.Now()),
Caller: t.name,
Message: msg,
})
}
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(fmt.Sprintf(msg, args...))
}
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithRequest(req)
}
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !t.enabled {
return nil
}
return t.addTrace(msg).WithResponse(resp)
}

View File

@@ -0,0 +1,175 @@
package middleware
import (
"net"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
reqVarGetter func(*http.Request) string
respVarGetter func(*http.Response) string
)
var (
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
const (
VarRequestMethod = "$req_method"
VarRequestScheme = "$req_scheme"
VarRequestHost = "$req_host"
VarRequestPort = "$req_port"
VarRequestPath = "$req_path"
VarRequestAddr = "$req_addr"
VarRequestQuery = "$req_query"
VarRequestURL = "$req_url"
VarRequestURI = "$req_uri"
VarRequestContentType = "$req_content_type"
VarRequestContentLen = "$req_content_length"
VarRemoteHost = "$remote_host"
VarRemotePort = "$remote_port"
VarRemoteAddr = "$remote_addr"
VarUpstreamName = "$upstream_name"
VarUpstreamScheme = "$upstream_scheme"
VarUpstreamHost = "$upstream_host"
VarUpstreamPort = "$upstream_port"
VarUpstreamAddr = "$upstream_addr"
VarUpstreamURL = "$upstream_url"
VarRespContentType = "$resp_content_type"
VarRespContentLen = "$resp_content_length"
VarRespStatusCode = "$status_code"
)
var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return ""
},
VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientPort
}
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
}
return upHost
},
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" {
return ""
}
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort
}
return upScheme + "://" + upAddr
},
}
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *http.Request, resp *http.Response, s string) string {
if req != nil {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
name := match[5 : len(match)-1]
for k, v := range req.URL.Query() {
if strings.EqualFold(k, name) {
return v[0]
}
}
return ""
})
// Replace request headers
s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
return req.Header.Get(header)
})
}
if resp != nil {
// Replace response headers
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
return resp.Header.Get(header)
})
}
// Replace static variables
if req != nil {
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := staticReqVarSubsMap[match]; ok {
return fn(req)
}
return match
})
}
if resp != nil {
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := staticRespVarSubsMap[match]; ok {
return fn(resp)
}
return match
})
}
return s
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"net"
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
type (
setXForwarded struct{}
hideXForwarded struct{}
)
var (
SetXForwarded = NewMiddleware[setXForwarded]()
HideXForwarded = NewMiddleware[hideXForwarded]()
)
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
}
return true
}
// before implements RequestModifier.
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
toDelete := make([]string, 0, len(r.Header))
for k := range r.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
toDelete = append(toDelete, k)
}
}
for _, k := range toDelete {
r.Header.Del(k)
}
return true
}

View File

@@ -0,0 +1,116 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs
package gphttp
import (
"bufio"
"fmt"
"net"
"net/http"
)
type (
ModifyResponseFunc func(*http.Response) error
ModifyResponseWriter struct {
w http.ResponseWriter
r *http.Request
headerSent bool
code int
size int
modifier ModifyResponseFunc
modified bool
modifierErr error
}
)
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
return &ModifyResponseWriter{
w: w,
r: r,
modifier: f,
code: http.StatusOK,
}
}
func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter {
return w.w
}
func (w *ModifyResponseWriter) StatusCode() int {
return w.code
}
func (w *ModifyResponseWriter) Size() int {
return w.size
}
func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent {
return
}
if code >= http.StatusContinue && code < http.StatusOK {
w.w.WriteHeader(code)
}
defer func() {
w.headerSent = true
w.code = code
}()
if w.modifier == nil || w.modified {
w.w.WriteHeader(code)
return
}
resp := http.Response{
StatusCode: code,
Header: w.w.Header(),
Request: w.r,
ContentLength: int64(w.size),
}
if err := w.modifier(&resp); err != nil {
w.modifierErr = fmt.Errorf("response modifier error: %w", err)
resp.Status = w.modifierErr.Error()
w.w.WriteHeader(http.StatusInternalServerError)
return
}
w.modified = true
w.w.WriteHeader(code)
}
func (w *ModifyResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
w.WriteHeader(w.code)
if w.modifierErr != nil {
return 0, w.modifierErr
}
n, err := w.w.Write(b)
w.size += n
return n, err
}
// Hijack hijacks the connection.
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.w.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
}
// Flush sends any buffered data to the client.
func (w *ModifyResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,530 @@
// Copyright 2011 The Go Authors.
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
// https://cs.opensource.google/go/go/+/master:LICENSE
package reverseproxy
// This is a small mod on net/http/httputil/reverseproxy.go
// that boosts performance in some cases
// and compatible to other modules of this project
// Copyright (c) 2024 yusing
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"sync"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
In *http.Request
// Out is the request which will be sent by the proxy.
// The Rewrite function may modify or replace this request.
// Hop-by-hop headers are removed from this request
// before Rewrite is called.
Out *http.Request
}
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
// X-Forwarded-Proto headers of the outbound request.
//
// - The X-Forwarded-For header is set to the client IP address.
// - The X-Forwarded-Host header is set to the host name requested
// by the client.
// - The X-Forwarded-Proto header is set to "http" or "https", depending
// on whether the inbound request was made on a TLS-enabled connection.
//
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
// [ReverseProxy] when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
// r.SetXForwarded()
// }
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
zerolog.Logger
// The transport used to perform proxy requests.
Transport http.RoundTripper
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called before ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
AccessLogger *accesslog.AccessLogger
HandlerFunc http.HandlerFunc
TargetName string
TargetURL *types.URL
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
// NewReverseProxy returns a new [ReverseProxy] that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{
Logger: logging.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
}
rp.HandlerFunc = rp.handler
return rp
}
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
targetQuery := p.TargetURL.RawQuery
req.URL.Scheme = p.TargetURL.Scheme
req.URL.Host = p.TargetURL.Host
req.URL.Path, req.URL.RawPath = joinURLPath(&p.TargetURL.URL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
reqURL := r.Host + r.URL.Path
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF),
errors.Is(err, context.DeadlineExceeded):
logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
default:
var recordErr tls.RecordHeaderError
if errors.As(err, &recordErr) {
logging.Error().
Str("url", reqURL).
Msgf(`scheme was likely misconfigured as https,
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
logging.Err(err).Msg("underlying error")
} else {
logging.Err(err).Str("url", reqURL).Msg("http proxy error")
}
}
if writeHeader {
rw.WriteHeader(http.StatusInternalServerError)
}
if p.AccessLogger != nil {
p.AccessLogger.LogError(r, err)
}
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
res.Request = origReq
err := p.ModifyResponse(res)
res.Request = req
if err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
}
return true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.HandlerFunc(rw, req)
}
func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
// a Context that carries a cancellation signal, don't
// bother spinning up a goroutine to watch the CloseNotify
// channel (if any).
//
// If the request Context has a nil Done channel (which
// means it is either context.Background, or a custom
// Context implementation with no cancellation signal),
// then consult the CloseNotifier if available.
} else if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
p.rewriteRequestURL(outreq)
outreq.Close = false
reqUpType := httpheaders.UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return
}
req.Header.Del("Forwarded")
httpheaders.RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
if strings.EqualFold(reqUpType, "websocket") {
cleanWebsocketHeaders(outreq)
}
}
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
xff, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
xff = req.RemoteAddr
}
if len(prior) > 0 {
xff = strings.Join(prior, ", ") + ", " + xff
}
if !omit {
outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff)
}
var reqScheme string
if req.TLS != nil {
reqScheme = "https"
} else {
reqScheme = "http"
}
outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method)
outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme)
outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host)
outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI)
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
var (
roundTripMutex sync.Mutex
roundTripDone bool
)
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
roundTripMutex.Lock()
defer roundTripMutex.Unlock()
if roundTripDone {
// If RoundTrip has returned, don't try to further modify
// the ResponseWriter's header map.
return nil
}
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
roundTripMutex.Lock()
roundTripDone = true
roundTripMutex.Unlock()
if err != nil {
p.errorHandler(rw, outreq, err, false)
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: req,
TLS: req.TLS,
}
}
if p.AccessLogger != nil {
defer func() {
p.AccessLogger.Log(req, res)
}()
}
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, req, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
return
}
httpheaders.RemoveHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, req, outreq) {
return
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
err = U.CopyClose(U.NewContextWriter(ctx, rw), U.NewContextReader(ctx, res.Body)) // close now, instead of defer, to populate res.Trailer
if err != nil {
if !errors.Is(err, context.Canceled) {
p.errorHandler(rw, req, err, false)
}
return
}
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
http.NewResponseController(rw).Flush()
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
// reference: https://github.com/traefik/traefik/blob/master/pkg/proxy/httputil/proxy.go
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebsocketHeaders(req *http.Request) {
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
delete(req.Header, "Sec-Websocket-Key")
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
delete(req.Header, "Sec-Websocket-Extensions")
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
delete(req.Header, "Sec-Websocket-Accept")
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
delete(req.Header, "Sec-Websocket-Protocol")
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
delete(req.Header, "Sec-Websocket-Version")
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := httpheaders.UpgradeType(req.Header)
resUpType := httpheaders.UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return
}
if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true)
return
}
rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true)
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
if hijackErr != nil {
p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true)
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */
bdp.Start()
}
func IsPrint(s string) bool {
for _, r := range s {
if r < ' ' || r > '~' {
return false
}
}
return true
}

View File

@@ -0,0 +1,31 @@
package gphttp
import "net/http"
type ServeMux struct {
*http.ServeMux
}
func NewServeMux() ServeMux {
return ServeMux{http.NewServeMux()}
}
func (mux ServeMux) Handle(pattern string, handler http.Handler) (err error) {
defer func() {
if r := recover(); r != nil {
err = r.(error)
}
}()
mux.ServeMux.Handle(pattern, handler)
return
}
func (mux ServeMux) HandleFunc(pattern string, handler http.HandlerFunc) (err error) {
defer func() {
if r := recover(); r != nil {
err = r.(error)
}
}()
mux.ServeMux.HandleFunc(pattern, handler)
return
}

View File

@@ -0,0 +1,18 @@
package server
import (
"context"
"errors"
"net/http"
"github.com/rs/zerolog"
)
func HandleError(logger *zerolog.Logger, err error, msg string) {
switch {
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return
default:
logger.Fatal().Err(err).Msg(msg)
}
}

View File

@@ -0,0 +1,163 @@
package server
import (
"context"
"crypto/tls"
"io"
"log"
"net"
"net/http"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type Server struct {
Name string
CertProvider *autocert.Provider
http *http.Server
https *http.Server
httpStarted bool
httpsStarted bool
startTime time.Time
l zerolog.Logger
}
type Options struct {
Name string
HTTPAddr string
HTTPSAddr string
CertProvider *autocert.Provider
Handler http.Handler
}
func StartServer(parent task.Parent, opt Options) (s *Server) {
s = NewServer(opt)
s.Start(parent)
return s
}
func NewServer(opt Options) (s *Server) {
var httpSer, httpsSer *http.Server
logger := logging.With().Str("server", opt.Name).Logger()
certAvailable := false
if opt.CertProvider != nil {
_, err := opt.CertProvider.GetCert(nil)
certAvailable = err == nil
}
out := io.Discard
if common.IsDebug {
out = logger
}
if opt.HTTPAddr != "" {
httpSer = &http.Server{
Addr: opt.HTTPAddr,
Handler: opt.Handler,
ErrorLog: log.New(out, "", 0), // most are tls related
}
}
if certAvailable && opt.HTTPSAddr != "" {
httpsSer = &http.Server{
Addr: opt.HTTPSAddr,
Handler: opt.Handler,
ErrorLog: log.New(out, "", 0), // most are tls related
TLSConfig: &tls.Config{
GetCertificate: opt.CertProvider.GetCert,
},
}
}
return &Server{
Name: opt.Name,
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
l: logger,
}
}
// Start will start the http and https servers.
//
// If both are not set, this does nothing.
//
// Start() is non-blocking.
func (s *Server) Start(parent task.Parent) {
if s.http == nil && s.https == nil {
return
}
task := parent.Subtask("server."+s.Name, false)
s.startTime = time.Now()
if s.http != nil {
go func() {
err := s.http.ListenAndServe()
if err != nil {
s.handleErr(err, "failed to serve http server")
}
}()
s.httpStarted = true
s.l.Info().Str("addr", s.http.Addr).Msg("server started")
}
if s.https != nil {
go func() {
l, err := net.Listen("tcp", s.https.Addr)
if err != nil {
s.handleErr(err, "failed to listen on port")
return
}
defer l.Close()
s.handleErr(s.https.Serve(tls.NewListener(l, s.https.TLSConfig)), "failed to serve https server")
}()
s.httpsStarted = true
s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
}
task.OnCancel("stop", s.stop)
}
func (s *Server) stop() {
if s.http == nil && s.https == nil {
return
}
ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second)
defer cancel()
if s.http != nil && s.httpStarted {
err := s.http.Shutdown(ctx)
if err != nil {
s.handleErr(err, "failed to shutdown http server")
} else {
s.httpStarted = false
s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped")
}
}
if s.https != nil && s.httpsStarted {
err := s.https.Shutdown(ctx)
if err != nil {
s.handleErr(err, "failed to shutdown https server")
} else {
s.httpsStarted = false
s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped")
}
}
}
func (s *Server) Uptime() time.Duration {
return time.Since(s.startTime)
}
func (s *Server) handleErr(err error, msg string) {
HandleError(&s.l, err, msg)
}

View File

@@ -0,0 +1,11 @@
package gphttp
import "net/http"
func IsSuccess(status int) bool {
return status >= http.StatusOK && status < http.StatusMultipleChoices
}
func IsStatusCodeValid(status int) bool {
return http.StatusText(status) != ""
}

View File

@@ -0,0 +1,34 @@
package gphttp
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var DefaultDialer = net.Dialer{
Timeout: 5 * time.Second,
}
func NewTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: DefaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// DisableCompression: true, // Prevent double compression
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
}
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
tr := NewTransport()
tr.TLSClientConfig = tlsConfig
return tr
}