mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 01:08:31 +02:00
refactor and organize code
This commit is contained in:
173
internal/net/gphttp/accesslog/access_logger.go
Normal file
173
internal/net/gphttp/accesslog/access_logger.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type (
|
||||
AccessLogger struct {
|
||||
task *task.Task
|
||||
cfg *Config
|
||||
io AccessLogIO
|
||||
|
||||
buf bytes.Buffer // buffer for non-flushed log
|
||||
bufMu sync.RWMutex
|
||||
bufPool sync.Pool // buffer pool for formatting a single log line
|
||||
|
||||
flushThreshold int
|
||||
|
||||
Formatter
|
||||
}
|
||||
|
||||
AccessLogIO interface {
|
||||
io.ReadWriteCloser
|
||||
io.ReadWriteSeeker
|
||||
io.ReaderAt
|
||||
sync.Locker
|
||||
Name() string // file name or path
|
||||
Truncate(size int64) error
|
||||
}
|
||||
|
||||
Formatter interface {
|
||||
// Format writes a log line to line without a trailing newline
|
||||
Format(line *bytes.Buffer, req *http.Request, res *http.Response)
|
||||
SetGetTimeNow(getTimeNow func() time.Time)
|
||||
}
|
||||
)
|
||||
|
||||
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
|
||||
l := &AccessLogger{
|
||||
task: parent.Subtask("accesslog"),
|
||||
cfg: cfg,
|
||||
io: io,
|
||||
}
|
||||
if cfg.BufferSize < 1024 {
|
||||
cfg.BufferSize = DefaultBufferSize
|
||||
}
|
||||
|
||||
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
|
||||
switch l.cfg.Format {
|
||||
case FormatCommon:
|
||||
l.Formatter = &fmt
|
||||
case FormatCombined:
|
||||
l.Formatter = &CombinedFormatter{fmt}
|
||||
case FormatJSON:
|
||||
l.Formatter = &JSONFormatter{fmt}
|
||||
default: // should not happen, validation has done by validate tags
|
||||
panic("invalid access log format")
|
||||
}
|
||||
|
||||
l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80%
|
||||
l.buf.Grow(int(cfg.BufferSize))
|
||||
l.bufPool.New = func() any {
|
||||
return new(bytes.Buffer)
|
||||
}
|
||||
go l.start()
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool {
|
||||
if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) ||
|
||||
!l.cfg.Filters.Method.CheckKeep(req, res) ||
|
||||
!l.cfg.Filters.Headers.CheckKeep(req, res) ||
|
||||
!l.cfg.Filters.CIDR.CheckKeep(req, res) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
|
||||
if !l.checkKeep(req, res) {
|
||||
return
|
||||
}
|
||||
|
||||
line := l.bufPool.Get().(*bytes.Buffer)
|
||||
l.Format(line, req, res)
|
||||
line.WriteRune('\n')
|
||||
|
||||
l.bufMu.Lock()
|
||||
l.buf.Write(line.Bytes())
|
||||
line.Reset()
|
||||
l.bufPool.Put(line)
|
||||
l.bufMu.Unlock()
|
||||
}
|
||||
|
||||
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
||||
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Config() *Config {
|
||||
return l.cfg
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Rotate() error {
|
||||
if l.cfg.Retention == nil {
|
||||
return nil
|
||||
}
|
||||
l.io.Lock()
|
||||
defer l.io.Unlock()
|
||||
|
||||
return l.cfg.Retention.rotateLogFile(l.io)
|
||||
}
|
||||
|
||||
func (l *AccessLogger) Flush(force bool) {
|
||||
if l.buf.Len() == 0 {
|
||||
return
|
||||
}
|
||||
if force || l.buf.Len() >= l.flushThreshold {
|
||||
l.bufMu.RLock()
|
||||
l.write(l.buf.Bytes())
|
||||
l.buf.Reset()
|
||||
l.bufMu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AccessLogger) handleErr(err error) {
|
||||
gperr.LogError("failed to write access log", err)
|
||||
}
|
||||
|
||||
func (l *AccessLogger) start() {
|
||||
defer func() {
|
||||
if l.buf.Len() > 0 { // flush last
|
||||
l.write(l.buf.Bytes())
|
||||
}
|
||||
l.io.Close()
|
||||
l.task.Finish(nil)
|
||||
}()
|
||||
|
||||
// periodic flush + threshold flush
|
||||
periodic := time.NewTicker(5 * time.Second)
|
||||
threshold := time.NewTicker(time.Second)
|
||||
defer periodic.Stop()
|
||||
defer threshold.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.task.Context().Done():
|
||||
return
|
||||
case <-periodic.C:
|
||||
l.Flush(true)
|
||||
case <-threshold.C:
|
||||
l.Flush(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AccessLogger) write(data []byte) {
|
||||
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
|
||||
_, err := l.io.Write(data)
|
||||
l.io.Unlock()
|
||||
if err != nil {
|
||||
l.handleErr(err)
|
||||
} else {
|
||||
logging.Debug().Msg("access log flushed to " + l.io.Name())
|
||||
}
|
||||
}
|
||||
127
internal/net/gphttp/accesslog/access_logger_test.go
Normal file
127
internal/net/gphttp/accesslog/access_logger_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
const (
|
||||
remote = "192.168.1.1"
|
||||
host = "example.com"
|
||||
uri = "/?bar=baz&foo=bar"
|
||||
uriRedacted = "/?bar=" + RedactedValue + "&foo=" + RedactedValue
|
||||
referer = "https://www.google.com/"
|
||||
proto = "HTTP/1.1"
|
||||
ua = "Go-http-client/1.1"
|
||||
status = http.StatusOK
|
||||
contentLength = 100
|
||||
method = http.MethodGet
|
||||
)
|
||||
|
||||
var (
|
||||
testTask = task.RootTask("test", false)
|
||||
testURL = Must(url.Parse("http://" + host + uri))
|
||||
req = &http.Request{
|
||||
RemoteAddr: remote,
|
||||
Method: method,
|
||||
Proto: proto,
|
||||
Host: testURL.Host,
|
||||
URL: testURL,
|
||||
Header: http.Header{
|
||||
"User-Agent": []string{ua},
|
||||
"Referer": []string{referer},
|
||||
"Cookie": []string{
|
||||
"foo=bar",
|
||||
"bar=baz",
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: status,
|
||||
ContentLength: contentLength,
|
||||
Header: http.Header{"Content-Type": []string{"text/plain"}},
|
||||
}
|
||||
)
|
||||
|
||||
func fmtLog(cfg *Config) (ts string, line string) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
t := time.Now()
|
||||
logger := NewAccessLogger(testTask, nil, cfg)
|
||||
logger.Formatter.SetGetTimeNow(func() time.Time {
|
||||
return t
|
||||
})
|
||||
logger.Format(&buf, req, resp)
|
||||
return t.Format(LogTimeFormat), buf.String()
|
||||
}
|
||||
|
||||
func TestAccessLoggerCommon(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Format = FormatCommon
|
||||
ts, log := fmtLog(config)
|
||||
ExpectEqual(t, log,
|
||||
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
|
||||
host, remote, ts, method, uri, proto, status, contentLength,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func TestAccessLoggerCombined(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Format = FormatCombined
|
||||
ts, log := fmtLog(config)
|
||||
ExpectEqual(t, log,
|
||||
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
|
||||
host, remote, ts, method, uri, proto, status, contentLength, referer, ua,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func TestAccessLoggerRedactQuery(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Format = FormatCommon
|
||||
config.Fields.Query.Default = FieldModeRedact
|
||||
ts, log := fmtLog(config)
|
||||
ExpectEqual(t, log,
|
||||
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
|
||||
host, remote, ts, method, uriRedacted, proto, status, contentLength,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
|
||||
t.Helper()
|
||||
config.Format = FormatJSON
|
||||
var entry JSONLogEntry
|
||||
_, log := fmtLog(config)
|
||||
err := json.Unmarshal([]byte(log), &entry)
|
||||
ExpectNoError(t, err)
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSON(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectEqual(t, entry.IP, remote)
|
||||
ExpectEqual(t, entry.Method, method)
|
||||
ExpectEqual(t, entry.Scheme, "http")
|
||||
ExpectEqual(t, entry.Host, testURL.Host)
|
||||
ExpectEqual(t, entry.URI, testURL.RequestURI())
|
||||
ExpectEqual(t, entry.Protocol, proto)
|
||||
ExpectEqual(t, entry.Status, status)
|
||||
ExpectEqual(t, entry.ContentType, "text/plain")
|
||||
ExpectEqual(t, entry.Size, contentLength)
|
||||
ExpectEqual(t, entry.Referer, referer)
|
||||
ExpectEqual(t, entry.UserAgent, ua)
|
||||
ExpectEqual(t, len(entry.Headers), 0)
|
||||
ExpectEqual(t, len(entry.Cookies), 0)
|
||||
}
|
||||
57
internal/net/gphttp/accesslog/config.go
Normal file
57
internal/net/gphttp/accesslog/config.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package accesslog
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/utils"
|
||||
|
||||
type (
|
||||
Format string
|
||||
Filters struct {
|
||||
StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"`
|
||||
Method LogFilter[HTTPMethod] `json:"method"`
|
||||
Host LogFilter[Host] `json:"host"`
|
||||
Headers LogFilter[*HTTPHeader] `json:"headers"` // header exists or header == value
|
||||
CIDR LogFilter[*CIDR] `json:"cidr"`
|
||||
}
|
||||
Fields struct {
|
||||
Headers FieldConfig `json:"headers"`
|
||||
Query FieldConfig `json:"query"`
|
||||
Cookies FieldConfig `json:"cookies"`
|
||||
}
|
||||
Config struct {
|
||||
BufferSize uint `json:"buffer_size" validate:"gte=1"`
|
||||
Format Format `json:"format" validate:"oneof=common combined json"`
|
||||
Path string `json:"path" validate:"required"`
|
||||
Filters Filters `json:"filters"`
|
||||
Fields Fields `json:"fields"`
|
||||
Retention *Retention `json:"retention"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
FormatCommon Format = "common"
|
||||
FormatCombined Format = "combined"
|
||||
FormatJSON Format = "json"
|
||||
)
|
||||
|
||||
const DefaultBufferSize = 64 * 1024 // 64KB
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
BufferSize: DefaultBufferSize,
|
||||
Format: FormatCombined,
|
||||
Fields: Fields{
|
||||
Headers: FieldConfig{
|
||||
Default: FieldModeDrop,
|
||||
},
|
||||
Query: FieldConfig{
|
||||
Default: FieldModeKeep,
|
||||
},
|
||||
Cookies: FieldConfig{
|
||||
Default: FieldModeDrop,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
||||
}
|
||||
53
internal/net/gphttp/accesslog/config_test.go
Normal file
53
internal/net/gphttp/accesslog/config_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestNewConfig(t *testing.T) {
|
||||
labels := map[string]string{
|
||||
"proxy.buffer_size": "10",
|
||||
"proxy.format": "combined",
|
||||
"proxy.path": "/tmp/access.log",
|
||||
"proxy.filters.status_codes.values": "200-299",
|
||||
"proxy.filters.method.values": "GET, POST",
|
||||
"proxy.filters.headers.values": "foo=bar, baz",
|
||||
"proxy.filters.headers.negative": "true",
|
||||
"proxy.filters.cidr.values": "192.168.10.0/24",
|
||||
"proxy.fields.headers.default": "keep",
|
||||
"proxy.fields.headers.config.foo": "redact",
|
||||
"proxy.fields.query.default": "drop",
|
||||
"proxy.fields.query.config.foo": "keep",
|
||||
"proxy.fields.cookies.default": "redact",
|
||||
"proxy.fields.cookies.config.foo": "keep",
|
||||
}
|
||||
parsed, err := docker.ParseLabels(labels)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
var config Config
|
||||
err = utils.Deserialize(parsed, &config)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
ExpectEqual(t, config.BufferSize, 10)
|
||||
ExpectEqual(t, config.Format, FormatCombined)
|
||||
ExpectEqual(t, config.Path, "/tmp/access.log")
|
||||
ExpectDeepEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
|
||||
ExpectEqual(t, len(config.Filters.Method.Values), 2)
|
||||
ExpectDeepEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
|
||||
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
|
||||
ExpectDeepEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
|
||||
ExpectTrue(t, config.Filters.Headers.Negative)
|
||||
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
|
||||
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
|
||||
ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep)
|
||||
ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
|
||||
ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop)
|
||||
ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep)
|
||||
ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact)
|
||||
ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
|
||||
}
|
||||
103
internal/net/gphttp/accesslog/fields.go
Normal file
103
internal/net/gphttp/accesslog/fields.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type (
|
||||
FieldConfig struct {
|
||||
Default FieldMode `json:"default" validate:"oneof=keep drop redact"`
|
||||
Config map[string]FieldMode `json:"config" validate:"dive,oneof=keep drop redact"`
|
||||
}
|
||||
FieldMode string
|
||||
)
|
||||
|
||||
const (
|
||||
FieldModeKeep FieldMode = "keep"
|
||||
FieldModeDrop FieldMode = "drop"
|
||||
FieldModeRedact FieldMode = "redact"
|
||||
|
||||
RedactedValue = "REDACTED"
|
||||
)
|
||||
|
||||
func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V {
|
||||
if len(cfg.Config) == 0 {
|
||||
switch cfg.Default {
|
||||
case FieldModeKeep:
|
||||
return m
|
||||
case FieldModeDrop:
|
||||
return nil
|
||||
case FieldModeRedact:
|
||||
redacted := make(map[string]V)
|
||||
for k := range m {
|
||||
redacted[k] = redactedV
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
}
|
||||
|
||||
if len(m) == 0 {
|
||||
return m
|
||||
}
|
||||
|
||||
newMap := make(map[string]V, len(m))
|
||||
for k := range m {
|
||||
var mode FieldMode
|
||||
var ok bool
|
||||
if mode, ok = cfg.Config[k]; !ok {
|
||||
mode = cfg.Default
|
||||
}
|
||||
switch mode {
|
||||
case FieldModeKeep:
|
||||
newMap[k] = m[k]
|
||||
case FieldModeRedact:
|
||||
newMap[k] = redactedV
|
||||
}
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn {
|
||||
if len(s) == 0 ||
|
||||
len(cfg.Config) == 0 && cfg.Default == FieldModeDrop {
|
||||
return nil
|
||||
}
|
||||
newMap := make(map[string]VReturn, len(s))
|
||||
for _, v := range s {
|
||||
var mode FieldMode
|
||||
var ok bool
|
||||
k := getKey(v)
|
||||
if mode, ok = cfg.Config[k]; !ok {
|
||||
mode = cfg.Default
|
||||
}
|
||||
switch mode {
|
||||
case FieldModeKeep:
|
||||
newMap[k] = convert(v)
|
||||
case FieldModeRedact:
|
||||
newMap[k] = redact(v)
|
||||
}
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header {
|
||||
return processMap(cfg, headers, []string{RedactedValue})
|
||||
}
|
||||
|
||||
func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values {
|
||||
return processMap(cfg, q, []string{RedactedValue})
|
||||
}
|
||||
|
||||
func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string {
|
||||
return processSlice(cfg, cookies,
|
||||
func(c *http.Cookie) string {
|
||||
return c.Name
|
||||
},
|
||||
func(c *http.Cookie) string {
|
||||
return c.Value
|
||||
},
|
||||
func(c *http.Cookie) string {
|
||||
return RedactedValue
|
||||
})
|
||||
}
|
||||
96
internal/net/gphttp/accesslog/fields_test.go
Normal file
96
internal/net/gphttp/accesslog/fields_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
// Cookie header should be removed,
|
||||
// stored in JSONLogEntry.Cookies instead.
|
||||
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Headers.Default = FieldModeKeep
|
||||
entry := getJSONEntry(t, config)
|
||||
for k, v := range req.Header {
|
||||
if k != "Cookie" {
|
||||
ExpectDeepEqual(t, entry.Headers[k], v)
|
||||
}
|
||||
}
|
||||
|
||||
config.Fields.Headers.Config = map[string]FieldMode{
|
||||
"Referer": FieldModeRedact,
|
||||
"User-Agent": FieldModeDrop,
|
||||
}
|
||||
entry = getJSONEntry(t, config)
|
||||
ExpectDeepEqual(t, entry.Headers["Referer"], []string{RedactedValue})
|
||||
ExpectDeepEqual(t, entry.Headers["User-Agent"], nil)
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Headers.Default = FieldModeDrop
|
||||
entry := getJSONEntry(t, config)
|
||||
for k := range req.Header {
|
||||
ExpectDeepEqual(t, entry.Headers[k], nil)
|
||||
}
|
||||
|
||||
config.Fields.Headers.Config = map[string]FieldMode{
|
||||
"Referer": FieldModeKeep,
|
||||
"User-Agent": FieldModeRedact,
|
||||
}
|
||||
entry = getJSONEntry(t, config)
|
||||
ExpectDeepEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
|
||||
ExpectDeepEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Headers.Default = FieldModeRedact
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
|
||||
for k := range req.Header {
|
||||
if k != "Cookie" {
|
||||
ExpectDeepEqual(t, entry.Headers[k], []string{RedactedValue})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Headers.Default = FieldModeKeep
|
||||
config.Fields.Cookies.Default = FieldModeKeep
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
|
||||
for _, cookie := range req.Cookies() {
|
||||
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Headers.Default = FieldModeKeep
|
||||
config.Fields.Cookies.Default = FieldModeRedact
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
|
||||
for _, cookie := range req.Cookies() {
|
||||
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONDropQuery(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Query.Default = FieldModeDrop
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectDeepEqual(t, entry.Query["foo"], nil)
|
||||
ExpectDeepEqual(t, entry.Query["bar"], nil)
|
||||
}
|
||||
|
||||
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.Fields.Query.Default = FieldModeRedact
|
||||
entry := getJSONEntry(t, config)
|
||||
ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue})
|
||||
ExpectDeepEqual(t, entry.Query["bar"], []string{RedactedValue})
|
||||
}
|
||||
69
internal/net/gphttp/accesslog/file_logger.go
Normal file
69
internal/net/gphttp/accesslog/file_logger.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
*os.File
|
||||
sync.Mutex
|
||||
|
||||
// os.File.Name() may not equal to key of `openedFiles`.
|
||||
// Store it for later delete from `openedFiles`.
|
||||
path string
|
||||
|
||||
refCount *utils.RefCount
|
||||
}
|
||||
|
||||
var (
|
||||
openedFiles = make(map[string]*File)
|
||||
openedFilesMu sync.Mutex
|
||||
)
|
||||
|
||||
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
|
||||
openedFilesMu.Lock()
|
||||
|
||||
var file *File
|
||||
path := path.Clean(cfg.Path)
|
||||
if opened, ok := openedFiles[path]; ok {
|
||||
opened.refCount.Add()
|
||||
file = opened
|
||||
} else {
|
||||
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
openedFilesMu.Unlock()
|
||||
return nil, fmt.Errorf("access log open error: %w", err)
|
||||
}
|
||||
file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
|
||||
openedFiles[path] = file
|
||||
go file.closeOnZero()
|
||||
}
|
||||
|
||||
openedFilesMu.Unlock()
|
||||
return NewAccessLogger(parent, file, cfg), nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.refCount.Sub()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *File) closeOnZero() {
|
||||
defer logging.Debug().
|
||||
Str("path", f.path).
|
||||
Msg("access log closed")
|
||||
|
||||
<-f.refCount.Zero()
|
||||
|
||||
openedFilesMu.Lock()
|
||||
delete(openedFiles, f.path)
|
||||
openedFilesMu.Unlock()
|
||||
f.File.Close()
|
||||
}
|
||||
95
internal/net/gphttp/accesslog/file_logger_test.go
Normal file
95
internal/net/gphttp/accesslog/file_logger_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Path = "test.log"
|
||||
parent := task.RootTask("test", false)
|
||||
|
||||
loggerCount := 10
|
||||
accessLogIOs := make([]AccessLogIO, loggerCount)
|
||||
|
||||
// make test log file
|
||||
file, err := os.Create(cfg.Path)
|
||||
ExpectNoError(t, err)
|
||||
file.Close()
|
||||
t.Cleanup(func() {
|
||||
ExpectNoError(t, os.Remove(cfg.Path))
|
||||
})
|
||||
|
||||
for i := range loggerCount {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
logger, err := NewFileAccessLogger(parent, cfg)
|
||||
ExpectNoError(t, err)
|
||||
accessLogIOs[index] = logger.io
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
firstIO := accessLogIOs[0]
|
||||
for _, io := range accessLogIOs {
|
||||
ExpectEqual(t, io, firstIO)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
|
||||
var file MockFile
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.BufferSize = 1024
|
||||
parent := task.RootTask("test", false)
|
||||
|
||||
loggerCount := 5
|
||||
logCountPerLogger := 10
|
||||
loggers := make([]*AccessLogger, loggerCount)
|
||||
|
||||
for i := range loggerCount {
|
||||
loggers[i] = NewAccessLogger(parent, &file, cfg)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
resp := &http.Response{StatusCode: http.StatusOK}
|
||||
|
||||
for _, logger := range loggers {
|
||||
wg.Add(1)
|
||||
go func(l *AccessLogger) {
|
||||
defer wg.Done()
|
||||
parallelLog(l, req, resp, logCountPerLogger)
|
||||
l.Flush(true)
|
||||
}(logger)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expected := loggerCount * logCountPerLogger
|
||||
actual := file.Count()
|
||||
ExpectEqual(t, actual, expected)
|
||||
}
|
||||
|
||||
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
logger.Log(req, resp)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
99
internal/net/gphttp/accesslog/filter.go
Normal file
99
internal/net/gphttp/accesslog/filter.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
LogFilter[T Filterable] struct {
|
||||
Negative bool
|
||||
Values []T
|
||||
}
|
||||
Filterable interface {
|
||||
comparable
|
||||
Fulfill(req *http.Request, res *http.Response) bool
|
||||
}
|
||||
HTTPMethod string
|
||||
HTTPHeader struct {
|
||||
Key, Value string
|
||||
}
|
||||
Host string
|
||||
CIDR struct{ types.CIDR }
|
||||
)
|
||||
|
||||
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
|
||||
|
||||
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
|
||||
if len(f.Values) == 0 {
|
||||
return !f.Negative
|
||||
}
|
||||
for _, check := range f.Values {
|
||||
if check.Fulfill(req, res) {
|
||||
return !f.Negative
|
||||
}
|
||||
}
|
||||
return f.Negative
|
||||
}
|
||||
|
||||
func (r *StatusCodeRange) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
return r.Includes(res.StatusCode)
|
||||
}
|
||||
|
||||
func (method HTTPMethod) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
return req.Method == string(method)
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (k *HTTPHeader) Parse(v string) error {
|
||||
split := strutils.SplitRune(v, '=')
|
||||
switch len(split) {
|
||||
case 1:
|
||||
split = append(split, "")
|
||||
case 2:
|
||||
default:
|
||||
return ErrInvalidHTTPHeaderFilter.Subject(v)
|
||||
}
|
||||
k.Key = split[0]
|
||||
k.Value = split[1]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *HTTPHeader) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
wanted := k.Value
|
||||
// non canonical key matching
|
||||
got, ok := req.Header[k.Key]
|
||||
if wanted == "" {
|
||||
return ok
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, v := range got {
|
||||
if strings.EqualFold(v, wanted) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
return req.Host == string(h)
|
||||
}
|
||||
|
||||
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
ip, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = req.RemoteAddr
|
||||
}
|
||||
netIP := net.ParseIP(ip)
|
||||
if netIP == nil {
|
||||
return false
|
||||
}
|
||||
return cidr.Contains(netIP)
|
||||
}
|
||||
188
internal/net/gphttp/accesslog/filter_test.go
Normal file
188
internal/net/gphttp/accesslog/filter_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestStatusCodeFilter(t *testing.T) {
|
||||
values := []*StatusCodeRange{
|
||||
strutils.MustParse[*StatusCodeRange]("200-308"),
|
||||
}
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
filter := &LogFilter[*StatusCodeRange]{}
|
||||
ExpectTrue(t, filter.CheckKeep(nil, nil))
|
||||
|
||||
// keep any 2xx 3xx (inclusive)
|
||||
filter.Values = values
|
||||
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}))
|
||||
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
}))
|
||||
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusMultipleChoices,
|
||||
}))
|
||||
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusPermanentRedirect,
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("negative", func(t *testing.T) {
|
||||
filter := &LogFilter[*StatusCodeRange]{
|
||||
Negative: true,
|
||||
}
|
||||
ExpectFalse(t, filter.CheckKeep(nil, nil))
|
||||
|
||||
// drop any 2xx 3xx (inclusive)
|
||||
filter.Values = values
|
||||
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}))
|
||||
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
}))
|
||||
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusMultipleChoices,
|
||||
}))
|
||||
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
|
||||
StatusCode: http.StatusPermanentRedirect,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMethodFilter(t *testing.T) {
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
filter := &LogFilter[HTTPMethod]{}
|
||||
ExpectTrue(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
}, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodPost,
|
||||
}, nil))
|
||||
|
||||
// keep get only
|
||||
filter.Values = []HTTPMethod{http.MethodGet}
|
||||
ExpectTrue(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
}, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodPost,
|
||||
}, nil))
|
||||
})
|
||||
|
||||
t.Run("negative", func(t *testing.T) {
|
||||
filter := &LogFilter[HTTPMethod]{
|
||||
Negative: true,
|
||||
}
|
||||
ExpectFalse(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
}, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodPost,
|
||||
}, nil))
|
||||
|
||||
// drop post only
|
||||
filter.Values = []HTTPMethod{http.MethodPost}
|
||||
ExpectFalse(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodPost,
|
||||
}, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
}, nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHeaderFilter(t *testing.T) {
|
||||
fooBar := &http.Request{
|
||||
Header: http.Header{
|
||||
"Foo": []string{"bar"},
|
||||
},
|
||||
}
|
||||
fooBaz := &http.Request{
|
||||
Header: http.Header{
|
||||
"Foo": []string{"baz"},
|
||||
},
|
||||
}
|
||||
headerFoo := []*HTTPHeader{
|
||||
strutils.MustParse[*HTTPHeader]("Foo"),
|
||||
}
|
||||
ExpectEqual(t, headerFoo[0].Key, "Foo")
|
||||
ExpectEqual(t, headerFoo[0].Value, "")
|
||||
headerFooBar := []*HTTPHeader{
|
||||
strutils.MustParse[*HTTPHeader]("Foo=bar"),
|
||||
}
|
||||
ExpectEqual(t, headerFooBar[0].Key, "Foo")
|
||||
ExpectEqual(t, headerFooBar[0].Value, "bar")
|
||||
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
filter := &LogFilter[*HTTPHeader]{}
|
||||
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
|
||||
|
||||
// keep any foo
|
||||
filter.Values = headerFoo
|
||||
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
|
||||
|
||||
// keep foo == bar
|
||||
filter.Values = headerFooBar
|
||||
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
|
||||
})
|
||||
t.Run("negative", func(t *testing.T) {
|
||||
filter := &LogFilter[*HTTPHeader]{
|
||||
Negative: true,
|
||||
}
|
||||
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
|
||||
|
||||
// drop any foo
|
||||
filter.Values = headerFoo
|
||||
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
|
||||
|
||||
// drop foo == bar
|
||||
filter.Values = headerFooBar
|
||||
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCIDRFilter(t *testing.T) {
|
||||
cidr := []*CIDR{
|
||||
strutils.MustParse[*CIDR]("192.168.10.0/24"),
|
||||
}
|
||||
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
|
||||
inCIDR := &http.Request{
|
||||
RemoteAddr: "192.168.10.1",
|
||||
}
|
||||
notInCIDR := &http.Request{
|
||||
RemoteAddr: "192.168.11.1",
|
||||
}
|
||||
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
filter := &LogFilter[*CIDR]{}
|
||||
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
|
||||
|
||||
filter.Values = cidr
|
||||
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
|
||||
})
|
||||
|
||||
t.Run("negative", func(t *testing.T) {
|
||||
filter := &LogFilter[*CIDR]{Negative: true}
|
||||
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
|
||||
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
|
||||
|
||||
filter.Values = cidr
|
||||
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
|
||||
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
|
||||
})
|
||||
}
|
||||
144
internal/net/gphttp/accesslog/formatter.go
Normal file
144
internal/net/gphttp/accesslog/formatter.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type (
|
||||
CommonFormatter struct {
|
||||
cfg *Fields
|
||||
GetTimeNow func() time.Time // for testing purposes only
|
||||
}
|
||||
CombinedFormatter struct{ CommonFormatter }
|
||||
JSONFormatter struct{ CommonFormatter }
|
||||
|
||||
JSONLogEntry struct {
|
||||
Time string `json:"time"`
|
||||
IP string `json:"ip"`
|
||||
Method string `json:"method"`
|
||||
Scheme string `json:"scheme"`
|
||||
Host string `json:"host"`
|
||||
URI string `json:"uri"`
|
||||
Protocol string `json:"protocol"`
|
||||
Status int `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ContentType string `json:"type"`
|
||||
Size int64 `json:"size"`
|
||||
Referer string `json:"referer"`
|
||||
UserAgent string `json:"useragent"`
|
||||
Query map[string][]string `json:"query,omitempty"`
|
||||
Headers map[string][]string `json:"headers,omitempty"`
|
||||
Cookies map[string]string `json:"cookies,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
|
||||
func scheme(req *http.Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func requestURI(u *url.URL, query url.Values) string {
|
||||
uri := u.EscapedPath()
|
||||
if len(query) > 0 {
|
||||
uri += "?" + query.Encode()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
func clientIP(req *http.Request) string {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientIP
|
||||
}
|
||||
return req.RemoteAddr
|
||||
}
|
||||
|
||||
// debug only.
|
||||
func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) {
|
||||
f.GetTimeNow = getTimeNow
|
||||
}
|
||||
|
||||
func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
|
||||
query := f.cfg.Query.ProcessQuery(req.URL.Query())
|
||||
|
||||
line.WriteString(req.Host)
|
||||
line.WriteRune(' ')
|
||||
|
||||
line.WriteString(clientIP(req))
|
||||
line.WriteString(" - - [")
|
||||
|
||||
line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
|
||||
line.WriteString("] \"")
|
||||
|
||||
line.WriteString(req.Method)
|
||||
line.WriteRune(' ')
|
||||
line.WriteString(requestURI(req.URL, query))
|
||||
line.WriteRune(' ')
|
||||
line.WriteString(req.Proto)
|
||||
line.WriteString("\" ")
|
||||
|
||||
line.WriteString(strconv.Itoa(res.StatusCode))
|
||||
line.WriteRune(' ')
|
||||
line.WriteString(strconv.FormatInt(res.ContentLength, 10))
|
||||
}
|
||||
|
||||
func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
|
||||
f.CommonFormatter.Format(line, req, res)
|
||||
line.WriteString(" \"")
|
||||
line.WriteString(req.Referer())
|
||||
line.WriteString("\" \"")
|
||||
line.WriteString(req.UserAgent())
|
||||
line.WriteRune('"')
|
||||
}
|
||||
|
||||
func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
|
||||
query := f.cfg.Query.ProcessQuery(req.URL.Query())
|
||||
headers := f.cfg.Headers.ProcessHeaders(req.Header)
|
||||
headers.Del("Cookie")
|
||||
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
|
||||
|
||||
entry := JSONLogEntry{
|
||||
Time: f.GetTimeNow().Format(LogTimeFormat),
|
||||
IP: clientIP(req),
|
||||
Method: req.Method,
|
||||
Scheme: scheme(req),
|
||||
Host: req.Host,
|
||||
URI: requestURI(req.URL, query),
|
||||
Protocol: req.Proto,
|
||||
Status: res.StatusCode,
|
||||
ContentType: res.Header.Get("Content-Type"),
|
||||
Size: res.ContentLength,
|
||||
Referer: req.Referer(),
|
||||
UserAgent: req.UserAgent(),
|
||||
Query: query,
|
||||
Headers: headers,
|
||||
Cookies: cookies,
|
||||
}
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
entry.Error = res.Status
|
||||
}
|
||||
|
||||
if entry.ContentType == "" {
|
||||
// try to get content type from request
|
||||
entry.ContentType = req.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
marshaller := json.NewEncoder(line)
|
||||
err := marshaller.Encode(entry)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to marshal json log")
|
||||
}
|
||||
}
|
||||
74
internal/net/gphttp/accesslog/mock_file.go
Normal file
74
internal/net/gphttp/accesslog/mock_file.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MockFile struct {
|
||||
data []byte
|
||||
position int64
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (m *MockFile) Seek(offset int64, whence int) (int64, error) {
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
m.position = offset
|
||||
case io.SeekCurrent:
|
||||
m.position += offset
|
||||
case io.SeekEnd:
|
||||
m.position = int64(len(m.data)) + offset
|
||||
}
|
||||
return m.position, nil
|
||||
}
|
||||
|
||||
func (m *MockFile) Write(p []byte) (n int, err error) {
|
||||
m.data = append(m.data, p...)
|
||||
n = len(p)
|
||||
m.position += int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MockFile) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *MockFile) Read(p []byte) (n int, err error) {
|
||||
if m.position >= int64(len(m.data)) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(p, m.data[m.position:])
|
||||
m.position += int64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
if off >= int64(len(m.data)) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(p, m.data[off:])
|
||||
m.position += int64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *MockFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockFile) Truncate(size int64) error {
|
||||
m.data = m.data[:size]
|
||||
m.position = size
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockFile) Count() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return bytes.Count(m.data[:m.position], []byte("\n"))
|
||||
}
|
||||
|
||||
func (m *MockFile) Len() int64 {
|
||||
return m.position
|
||||
}
|
||||
198
internal/net/gphttp/accesslog/retention.go
Normal file
198
internal/net/gphttp/accesslog/retention.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Retention struct {
|
||||
Days uint64 `json:"days"`
|
||||
Last uint64 `json:"last"`
|
||||
}
|
||||
|
||||
const chunkSizeMax int64 = 128 * 1024 // 128KB
|
||||
|
||||
var (
|
||||
ErrInvalidSyntax = gperr.New("invalid syntax")
|
||||
ErrZeroValue = gperr.New("zero value")
|
||||
)
|
||||
|
||||
// Syntax:
|
||||
//
|
||||
// <N> days|weeks|months
|
||||
//
|
||||
// last <N>
|
||||
//
|
||||
// Parse implements strutils.Parser.
|
||||
func (r *Retention) Parse(v string) (err error) {
|
||||
split := strutils.SplitSpace(v)
|
||||
if len(split) != 2 {
|
||||
return ErrInvalidSyntax.Subject(v)
|
||||
}
|
||||
switch split[0] {
|
||||
case "last":
|
||||
r.Last, err = strconv.ParseUint(split[1], 10, 64)
|
||||
default: // <N> days|weeks|months
|
||||
r.Days, err = strconv.ParseUint(split[0], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch split[1] {
|
||||
case "days":
|
||||
case "weeks":
|
||||
r.Days *= 7
|
||||
case "months":
|
||||
r.Days *= 30
|
||||
default:
|
||||
return ErrInvalidSyntax.Subject("unit " + split[1])
|
||||
}
|
||||
}
|
||||
if r.Days == 0 && r.Last == 0 {
|
||||
return ErrZeroValue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Retention) rotateLogFile(file AccessLogIO) (err error) {
|
||||
lastN := int(r.Last)
|
||||
days := int(r.Days)
|
||||
|
||||
// Seek to end to get file size
|
||||
size, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize ring buffer for last N lines
|
||||
lines := make([][]byte, 0, lastN|(days*1000))
|
||||
pos := size
|
||||
unprocessed := 0
|
||||
|
||||
var chunk [chunkSizeMax]byte
|
||||
var lastLine []byte
|
||||
|
||||
var shouldStop func() bool
|
||||
if days > 0 {
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
shouldStop = func() bool {
|
||||
return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
|
||||
}
|
||||
} else {
|
||||
shouldStop = func() bool {
|
||||
return len(lines) == lastN
|
||||
}
|
||||
}
|
||||
|
||||
// Read backwards until we have enough lines or reach start of file
|
||||
for pos > 0 {
|
||||
if pos > chunkSizeMax {
|
||||
pos -= chunkSizeMax
|
||||
} else {
|
||||
pos = 0
|
||||
}
|
||||
|
||||
// Seek to the current chunk
|
||||
if _, err = file.Seek(pos, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nRead int
|
||||
// Read the chunk
|
||||
if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// last unprocessed bytes + read bytes
|
||||
curChunk := chunk[:unprocessed+nRead]
|
||||
unprocessed = len(curChunk)
|
||||
|
||||
// Split into lines
|
||||
scanner := bufio.NewScanner(bytes.NewReader(curChunk))
|
||||
for !shouldStop() && scanner.Scan() {
|
||||
lastLine = scanner.Bytes()
|
||||
lines = append(lines, lastLine)
|
||||
unprocessed -= len(lastLine)
|
||||
}
|
||||
if shouldStop() {
|
||||
break
|
||||
}
|
||||
|
||||
// move unprocessed bytes to the beginning for next iteration
|
||||
copy(chunk[:], curChunk[unprocessed:])
|
||||
}
|
||||
|
||||
if days > 0 {
|
||||
// truncate to the end of the log within last N days
|
||||
return file.Truncate(pos)
|
||||
}
|
||||
|
||||
// write lines to buffer in reverse order
|
||||
// since we read them backwards
|
||||
var buf bytes.Buffer
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
buf.Write(lines[i])
|
||||
buf.WriteRune('\n')
|
||||
}
|
||||
|
||||
return writeTruncate(file, &buf)
|
||||
}
|
||||
|
||||
func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
|
||||
// Seek to beginning and truncate
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buffered := bufio.NewWriter(file)
|
||||
// Write buffer back to file
|
||||
nWritten, err := buffered.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = buffered.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Truncate file
|
||||
if err = file.Truncate(int64(nWritten)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check bytes written == buffer size
|
||||
if nWritten != buf.Len() {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseLogTime(line []byte) (t time.Time) {
|
||||
if len(line) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var start, end int
|
||||
const jsonStart = len(`{"time":"`)
|
||||
const jsonEnd = jsonStart + len(LogTimeFormat)
|
||||
|
||||
if len(line) == '{' { // possibly json log
|
||||
start = jsonStart
|
||||
end = jsonEnd
|
||||
} else { // possibly common or combined format
|
||||
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
|
||||
start = bytes.IndexRune(line, '[')
|
||||
end = bytes.IndexRune(line[start+1:], ']')
|
||||
if start == -1 || end == -1 || start >= end {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
timeStr := line[start+1 : end]
|
||||
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
|
||||
return
|
||||
}
|
||||
81
internal/net/gphttp/accesslog/retention_test.go
Normal file
81
internal/net/gphttp/accesslog/retention_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package accesslog_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestParseRetention(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected *Retention
|
||||
shouldErr bool
|
||||
}{
|
||||
{"30 days", &Retention{Days: 30}, false},
|
||||
{"2 weeks", &Retention{Days: 14}, false},
|
||||
{"last 5", &Retention{Last: 5}, false},
|
||||
{"invalid input", &Retention{}, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
r := &Retention{}
|
||||
err := r.Parse(test.input)
|
||||
if !test.shouldErr {
|
||||
ExpectNoError(t, err)
|
||||
} else {
|
||||
ExpectDeepEqual(t, r, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionCommonFormat(t *testing.T) {
|
||||
var file MockFile
|
||||
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
|
||||
Format: FormatCommon,
|
||||
BufferSize: 1024,
|
||||
})
|
||||
for range 10 {
|
||||
logger.Log(req, resp)
|
||||
}
|
||||
logger.Flush(true)
|
||||
// test.Finish(nil)
|
||||
|
||||
ExpectEqual(t, logger.Config().Retention, nil)
|
||||
ExpectTrue(t, file.Len() > 0)
|
||||
ExpectEqual(t, file.Count(), 10)
|
||||
|
||||
t.Run("keep last", func(t *testing.T) {
|
||||
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
|
||||
ExpectEqual(t, logger.Config().Retention.Days, 0)
|
||||
ExpectEqual(t, logger.Config().Retention.Last, 5)
|
||||
ExpectNoError(t, logger.Rotate())
|
||||
ExpectEqual(t, file.Count(), 5)
|
||||
})
|
||||
|
||||
_ = file.Truncate(0)
|
||||
|
||||
timeNow := time.Now()
|
||||
for i := range 10 {
|
||||
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
|
||||
return timeNow.AddDate(0, 0, -i)
|
||||
}
|
||||
logger.Log(req, resp)
|
||||
}
|
||||
logger.Flush(true)
|
||||
|
||||
// FIXME: keep days does not work
|
||||
t.Run("keep days", func(t *testing.T) {
|
||||
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
|
||||
ExpectEqual(t, logger.Config().Retention.Days, 3)
|
||||
ExpectEqual(t, logger.Config().Retention.Last, 0)
|
||||
ExpectNoError(t, logger.Rotate())
|
||||
ExpectEqual(t, file.Count(), 3)
|
||||
})
|
||||
}
|
||||
52
internal/net/gphttp/accesslog/status_code_range.go
Normal file
52
internal/net/gphttp/accesslog/status_code_range.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type StatusCodeRange struct {
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
|
||||
|
||||
func (r *StatusCodeRange) Includes(code int) bool {
|
||||
return r.Start <= code && code <= r.End
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (r *StatusCodeRange) Parse(v string) error {
|
||||
split := strutils.SplitRune(v, '-')
|
||||
switch len(split) {
|
||||
case 1:
|
||||
start, err := strconv.Atoi(split[0])
|
||||
if err != nil {
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
r.Start = start
|
||||
r.End = start
|
||||
return nil
|
||||
case 2:
|
||||
start, errStart := strconv.Atoi(split[0])
|
||||
end, errEnd := strconv.Atoi(split[1])
|
||||
if err := gperr.Join(errStart, errEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Start = start
|
||||
r.End = end
|
||||
return nil
|
||||
default:
|
||||
return ErrInvalidStatusCodeRange.Subject(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *StatusCodeRange) String() string {
|
||||
if r.Start == r.End {
|
||||
return strconv.Itoa(r.Start)
|
||||
}
|
||||
return strconv.Itoa(r.Start) + "-" + strconv.Itoa(r.End)
|
||||
}
|
||||
46
internal/net/gphttp/body.go
Normal file
46
internal/net/gphttp/body.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||
if _, err := w.Write(body); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, http.ErrHandlerTimeout),
|
||||
errors.Is(err, context.DeadlineExceeded):
|
||||
logging.Err(err).Msg("timeout writing body")
|
||||
default:
|
||||
logging.Err(err).Msg("failed to write body")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
|
||||
if len(code) > 0 {
|
||||
w.WriteHeader(code[0])
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var err error
|
||||
|
||||
switch data := data.(type) {
|
||||
case string:
|
||||
_, err = w.Write([]byte(fmt.Sprintf("%q", data)))
|
||||
case []byte:
|
||||
panic("use WriteBody instead")
|
||||
default:
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
LogError(r).Err(err).Msg("failed to encode json")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
78
internal/net/gphttp/content_type.go
Normal file
78
internal/net/gphttp/content_type.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
ContentType string
|
||||
AcceptContentType []ContentType
|
||||
)
|
||||
|
||||
func GetContentType(h http.Header) ContentType {
|
||||
ct := h.Get("Content-Type")
|
||||
if ct == "" {
|
||||
return ""
|
||||
}
|
||||
ct, _, err := mime.ParseMediaType(ct)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return ContentType(ct)
|
||||
}
|
||||
|
||||
func GetAccept(h http.Header) AcceptContentType {
|
||||
var accepts []ContentType
|
||||
for _, v := range h["Accept"] {
|
||||
ct, _, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
accepts = append(accepts, ContentType(ct))
|
||||
}
|
||||
return accepts
|
||||
}
|
||||
|
||||
func (ct ContentType) IsHTML() bool {
|
||||
return ct == "text/html" || ct == "application/xhtml+xml"
|
||||
}
|
||||
|
||||
func (ct ContentType) IsJSON() bool {
|
||||
return ct == "application/json"
|
||||
}
|
||||
|
||||
func (ct ContentType) IsPlainText() bool {
|
||||
return ct == "text/plain"
|
||||
}
|
||||
|
||||
func (act AcceptContentType) IsEmpty() bool {
|
||||
return len(act) == 0
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptHTML() bool {
|
||||
for _, v := range act {
|
||||
if v.IsHTML() || v == "text/*" || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptJSON() bool {
|
||||
for _, v := range act {
|
||||
if v.IsJSON() || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (act AcceptContentType) AcceptPlainText() bool {
|
||||
for _, v := range act {
|
||||
if v.IsPlainText() || v == "text/*" || v == "*/*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
41
internal/net/gphttp/content_type_test.go
Normal file
41
internal/net/gphttp/content_type_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestContentTypes(t *testing.T) {
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
|
||||
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
|
||||
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
|
||||
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
|
||||
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
|
||||
}
|
||||
|
||||
func TestAcceptContentTypes(t *testing.T) {
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
|
||||
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
|
||||
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
|
||||
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
|
||||
}
|
||||
27
internal/net/gphttp/default_client.go
Normal file
27
internal/net/gphttp/default_client.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
httpClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
KeepAlive: 60 * time.Second, // this is different from DisableKeepAlives
|
||||
}).DialContext,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
|
||||
Get = httpClient.Get
|
||||
Post = httpClient.Post
|
||||
Head = httpClient.Head
|
||||
)
|
||||
95
internal/net/gphttp/error.go
Normal file
95
internal/net/gphttp/error.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"syscall"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
// ServerError is for handling server errors.
|
||||
//
|
||||
// It logs the error and returns http.StatusInternalServerError to the client.
|
||||
// Status code can be specified as an argument.
|
||||
func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) {
|
||||
switch {
|
||||
case err == nil,
|
||||
errors.Is(err, context.Canceled),
|
||||
errors.Is(err, syscall.EPIPE),
|
||||
errors.Is(err, syscall.ECONNRESET):
|
||||
return
|
||||
}
|
||||
LogError(r).Msg(err.Error())
|
||||
if httpheaders.IsWebsocket(r.Header) {
|
||||
return
|
||||
}
|
||||
if len(code) == 0 {
|
||||
code = []int{http.StatusInternalServerError}
|
||||
}
|
||||
http.Error(w, http.StatusText(code[0]), code[0])
|
||||
}
|
||||
|
||||
// ClientError is for responding to client errors.
|
||||
//
|
||||
// It returns http.StatusBadRequest with reason to the client.
|
||||
// Status code can be specified as an argument.
|
||||
//
|
||||
// For JSON marshallable errors (e.g. gperr.Error), it returns the error details as JSON.
|
||||
// Otherwise, it returns the error details as plain text.
|
||||
func ClientError(w http.ResponseWriter, err error, code ...int) {
|
||||
if len(code) == 0 {
|
||||
code = []int{http.StatusBadRequest}
|
||||
}
|
||||
if gperr.IsJSONMarshallable(err) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(err)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
}
|
||||
http.Error(w, err.Error(), code[0])
|
||||
}
|
||||
|
||||
// JSONError returns a JSON response of gperr.Error with the given status code.
|
||||
func JSONError(w http.ResponseWriter, err gperr.Error, code int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(err)
|
||||
http.Error(w, err.Error(), code)
|
||||
}
|
||||
|
||||
// BadRequest returns a Bad Request response with the given error message.
|
||||
func BadRequest(w http.ResponseWriter, err string, code ...int) {
|
||||
if len(code) == 0 {
|
||||
code = []int{http.StatusBadRequest}
|
||||
}
|
||||
http.Error(w, err, code[0])
|
||||
}
|
||||
|
||||
// Unauthorized returns an Unauthorized response with the given error message.
|
||||
func Unauthorized(w http.ResponseWriter, err string) {
|
||||
BadRequest(w, err, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// NotFound returns a Not Found response with the given error message.
|
||||
func NotFound(w http.ResponseWriter, err string) {
|
||||
BadRequest(w, err, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func ErrMissingKey(k string) error {
|
||||
return gperr.New(k + " is required")
|
||||
}
|
||||
|
||||
func ErrInvalidKey(k string) error {
|
||||
return gperr.New(k + " is invalid")
|
||||
}
|
||||
|
||||
func ErrAlreadyExists(k, v string) error {
|
||||
return gperr.Errorf("%s %q already exists", k, v)
|
||||
}
|
||||
|
||||
func ErrNotFound(k, v string) error {
|
||||
return gperr.Errorf("%s %q not found", k, v)
|
||||
}
|
||||
86
internal/net/gphttp/gpwebsocket/utils.go
Normal file
86
internal/net/gphttp/gpwebsocket/utils.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package gpwebsocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
func warnNoMatchDomains() {
|
||||
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
}
|
||||
|
||||
var warnNoMatchDomainOnce sync.Once
|
||||
|
||||
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
var originPats []string
|
||||
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
|
||||
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
|
||||
if len(allowedDomains) == 0 || common.IsDebug {
|
||||
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
originPats = make([]string, len(allowedDomains))
|
||||
for i, domain := range allowedDomains {
|
||||
if domain[0] != '.' {
|
||||
originPats[i] = "*." + domain
|
||||
} else {
|
||||
originPats[i] = "*" + domain
|
||||
}
|
||||
}
|
||||
originPats = append(originPats, localAddresses...)
|
||||
}
|
||||
return websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
OriginPatterns: originPats,
|
||||
})
|
||||
}
|
||||
|
||||
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||
conn, err := Initiate(w, r)
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
//nolint:errcheck
|
||||
defer conn.CloseNow()
|
||||
|
||||
if err := do(conn); err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := do(conn); err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText writes a text message to the websocket connection.
|
||||
// It returns true if the message was written successfully, false otherwise.
|
||||
// It logs an error if the message is not written successfully.
|
||||
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
|
||||
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
|
||||
gperr.LogError("failed to write text message", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
7
internal/net/gphttp/httpheaders/sse.go
Normal file
7
internal/net/gphttp/httpheaders/sse.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package httpheaders
|
||||
|
||||
import "net/http"
|
||||
|
||||
func IsSSE(h http.Header) bool {
|
||||
return h.Get("Content-Type") == "text/event-stream"
|
||||
}
|
||||
119
internal/net/gphttp/httpheaders/utils.go
Normal file
119
internal/net/gphttp/httpheaders/utils.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package httpheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderXForwardedMethod = "X-Forwarded-Method"
|
||||
HeaderXForwardedFor = "X-Forwarded-For"
|
||||
HeaderXForwardedProto = "X-Forwarded-Proto"
|
||||
HeaderXForwardedHost = "X-Forwarded-Host"
|
||||
HeaderXForwardedPort = "X-Forwarded-Port"
|
||||
HeaderXForwardedURI = "X-Forwarded-Uri"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderContentLength = "Content-Length"
|
||||
|
||||
HeaderUpstreamName = "X-Godoxy-Upstream-Name"
|
||||
HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme"
|
||||
HeaderUpstreamHost = "X-Godoxy-Upstream-Host"
|
||||
HeaderUpstreamPort = "X-Godoxy-Upstream-Port"
|
||||
|
||||
HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect"
|
||||
)
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func UpgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
// RemoveHopByHopHeaders removes hop-by-hop headers.
|
||||
func RemoveHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strutils.SplitComma(f) {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
||||
func RemoveHop(h http.Header) {
|
||||
reqUpType := UpgradeType(h)
|
||||
RemoveHopByHopHeaders(h)
|
||||
|
||||
if reqUpType != "" {
|
||||
h.Set("Connection", "Upgrade")
|
||||
h.Set("Upgrade", reqUpType)
|
||||
} else {
|
||||
h.Del("Connection")
|
||||
}
|
||||
}
|
||||
|
||||
func CopyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FilterHeaders(h http.Header, allowed []string) http.Header {
|
||||
if len(allowed) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
filtered := make(http.Header)
|
||||
|
||||
for i, header := range allowed {
|
||||
values := h.Values(header)
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
filtered[http.CanonicalHeaderKey(allowed[i])] = append([]string(nil), values...)
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func HeaderToMap(h http.Header) map[string]string {
|
||||
result := make(map[string]string)
|
||||
for k, v := range h {
|
||||
if len(v) > 0 {
|
||||
result[k] = v[0] // Take the first value
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
21
internal/net/gphttp/httpheaders/websocket.go
Normal file
21
internal/net/gphttp/httpheaders/websocket.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package httpheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
|
||||
)
|
||||
|
||||
func WebsocketAllowedDomains(h http.Header) []string {
|
||||
return h[HeaderXGoDoxyWebsocketAllowedDomains]
|
||||
}
|
||||
|
||||
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
|
||||
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
|
||||
}
|
||||
|
||||
func IsWebsocket(h http.Header) bool {
|
||||
return UpgradeType(h) == "websocket"
|
||||
}
|
||||
91
internal/net/gphttp/loadbalancer/ip_hash.go
Normal file
91
internal/net/gphttp/loadbalancer/ip_hash.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
*LoadBalancer
|
||||
|
||||
realIP *middleware.Middleware
|
||||
pool Servers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl {
|
||||
impl := &ipHash{LoadBalancer: lb}
|
||||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err gperr.Error
|
||||
impl.realIP, err = middleware.RealIP.New(lb.Options)
|
||||
if err != nil {
|
||||
gperr.LogError("invalid real_ip options, ignoring", err, &impl.l)
|
||||
}
|
||||
return impl
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnAddServer(srv Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
impl.pool[i] = srv
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
impl.pool = append(impl.pool, srv)
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnRemoveServer(srv Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
impl.pool[i] = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) ServeHTTP(_ Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
if impl.realIP != nil {
|
||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
||||
} else {
|
||||
impl.serveHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
impl.l.Err(err).Msg("invalid remote address " + r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
|
||||
srv := impl.pool[idx]
|
||||
if srv == nil || srv.Status().Bad() {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
}
|
||||
srv.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func hashIP(ip string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(ip))
|
||||
return h.Sum32()
|
||||
}
|
||||
53
internal/net/gphttp/loadbalancer/least_conn.go
Normal file
53
internal/net/gphttp/loadbalancer/least_conn.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type leastConn struct {
|
||||
*LoadBalancer
|
||||
nConn F.Map[Server, *atomic.Int64]
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newLeastConn() impl {
|
||||
return &leastConn{
|
||||
LoadBalancer: lb,
|
||||
nConn: F.NewMapOf[Server, *atomic.Int64](),
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnAddServer(srv Server) {
|
||||
impl.nConn.Store(srv, new(atomic.Int64))
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnRemoveServer(srv Server) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
minConn = nConn
|
||||
srv = srvs[i]
|
||||
}
|
||||
}
|
||||
|
||||
minConn.Add(1)
|
||||
srv.ServeHTTP(rw, r)
|
||||
minConn.Add(-1)
|
||||
}
|
||||
314
internal/net/gphttp/loadbalancer/loadbalancer.go
Normal file
314
internal/net/gphttp/loadbalancer/loadbalancer.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
|
||||
)
|
||||
|
||||
// TODO: stats of each server.
|
||||
// TODO: support weighted mode.
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv Server)
|
||||
OnRemoveServer(srv Server)
|
||||
}
|
||||
|
||||
LoadBalancer struct {
|
||||
impl
|
||||
*Config
|
||||
|
||||
task *task.Task
|
||||
|
||||
pool Pool
|
||||
poolMu sync.Mutex
|
||||
|
||||
sumWeight Weight
|
||||
startTime time.Time
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
)
|
||||
|
||||
const maxWeight Weight = 100
|
||||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{
|
||||
Config: new(Config),
|
||||
pool: types.NewServerPool(),
|
||||
l: logging.With().Str("name", cfg.Link).Logger(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(parent task.Parent) gperr.Error {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
|
||||
parent.OnCancel("lb_remove_route", func() {
|
||||
routes.DeleteHTTPRoute(lb.Link)
|
||||
})
|
||||
lb.task.OnFinished("cleanup", func() {
|
||||
if lb.impl != nil {
|
||||
lb.pool.RangeAll(func(k string, v Server) {
|
||||
lb.impl.OnRemoveServer(v)
|
||||
})
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Task implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Task() *task.Task {
|
||||
return lb.task
|
||||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (lb *LoadBalancer) Finish(reason any) {
|
||||
lb.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) updateImpl() {
|
||||
switch lb.Mode {
|
||||
case types.ModeUnset, types.ModeRoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case types.ModeLeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case types.ModeIPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
}
|
||||
lb.pool.RangeAll(func(_ string, srv Server) {
|
||||
lb.impl.OnAddServer(srv)
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
if cfg != nil {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.Link = cfg.Link
|
||||
|
||||
if lb.Mode == types.ModeUnset && cfg.Mode != types.ModeUnset {
|
||||
lb.Mode = cfg.Mode
|
||||
if !lb.Mode.ValidateUpdate() {
|
||||
lb.l.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
|
||||
}
|
||||
lb.updateImpl()
|
||||
}
|
||||
|
||||
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
|
||||
lb.Options = cfg.Options
|
||||
}
|
||||
}
|
||||
|
||||
if lb.impl == nil {
|
||||
lb.updateImpl()
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) AddServer(srv Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if lb.pool.Has(srv.Name()) {
|
||||
old, _ := lb.pool.Load(srv.Name())
|
||||
lb.sumWeight -= old.Weight()
|
||||
lb.impl.OnRemoveServer(old)
|
||||
}
|
||||
lb.pool.Store(srv.Name(), srv)
|
||||
lb.sumWeight += srv.Weight()
|
||||
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
|
||||
lb.l.Debug().
|
||||
Str("action", "add").
|
||||
Str("server", srv.Name()).
|
||||
Msgf("%d servers available", lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
if !lb.pool.Has(srv.Name()) {
|
||||
return
|
||||
}
|
||||
|
||||
lb.pool.Delete(srv.Name())
|
||||
|
||||
lb.sumWeight -= srv.Weight()
|
||||
lb.rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
lb.l.Debug().
|
||||
Str("action", "remove").
|
||||
Str("server", srv.Name()).
|
||||
Msgf("%d servers left", lb.pool.Size())
|
||||
|
||||
if lb.pool.Size() == 0 {
|
||||
lb.task.Finish("no server left")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) rebalance() {
|
||||
if lb.sumWeight == maxWeight {
|
||||
return
|
||||
}
|
||||
|
||||
poolSize := lb.pool.Size()
|
||||
if poolSize == 0 {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / Weight(poolSize)
|
||||
remainder := maxWeight % Weight(poolSize)
|
||||
lb.pool.RangeAll(func(_ string, s Server) {
|
||||
w := weightEach
|
||||
lb.sumWeight += weightEach
|
||||
if remainder > 0 {
|
||||
w++
|
||||
remainder--
|
||||
}
|
||||
s.SetWeight(w)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// scale evenly
|
||||
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||
lb.sumWeight = 0
|
||||
|
||||
lb.pool.RangeAll(func(_ string, s Server) {
|
||||
s.SetWeight(Weight(float64(s.Weight()) * scaleFactor))
|
||||
lb.sumWeight += s.Weight()
|
||||
})
|
||||
|
||||
delta := maxWeight - lb.sumWeight
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
lb.pool.Range(func(_ string, s Server) bool {
|
||||
if delta == 0 {
|
||||
return false
|
||||
}
|
||||
if delta > 0 {
|
||||
s.SetWeight(s.Weight() + 1)
|
||||
lb.sumWeight++
|
||||
delta--
|
||||
} else {
|
||||
s.SetWeight(s.Weight() - 1)
|
||||
lb.sumWeight--
|
||||
delta++
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
srvs := lb.availServers()
|
||||
if len(srvs) == 0 {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
if r.Header.Get(httpheaders.HeaderGoDoxyCheckRedirect) != "" {
|
||||
// wake all servers
|
||||
for _, srv := range srvs {
|
||||
if err := srv.TryWake(); err != nil {
|
||||
lb.l.Warn().Err(err).
|
||||
Str("server", srv.Name()).
|
||||
Msg("failed to wake server")
|
||||
}
|
||||
}
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
extra := make(map[string]any)
|
||||
lb.pool.RangeAll(func(k string, v Server) {
|
||||
extra[v.Name()] = v
|
||||
})
|
||||
|
||||
return (&monitor.JSONRepresentation{
|
||||
Name: lb.Name(),
|
||||
Status: lb.Status(),
|
||||
Started: lb.startTime,
|
||||
Uptime: lb.Uptime(),
|
||||
Extra: map[string]any{
|
||||
"config": lb.Config,
|
||||
"pool": extra,
|
||||
},
|
||||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Name() string {
|
||||
return lb.Link
|
||||
}
|
||||
|
||||
// Status implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Status() health.Status {
|
||||
if lb.pool.Size() == 0 {
|
||||
return health.StatusUnknown
|
||||
}
|
||||
|
||||
isHealthy := true
|
||||
lb.pool.Range(func(_ string, srv Server) bool {
|
||||
if srv.Status().Bad() {
|
||||
isHealthy = false
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !isHealthy {
|
||||
return health.StatusUnhealthy
|
||||
}
|
||||
return health.StatusHealthy
|
||||
}
|
||||
|
||||
// Uptime implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Uptime() time.Duration {
|
||||
return time.Since(lb.startTime)
|
||||
}
|
||||
|
||||
// Latency implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Latency() time.Duration {
|
||||
var sum time.Duration
|
||||
lb.pool.RangeAll(func(_ string, srv Server) {
|
||||
sum += srv.Latency()
|
||||
})
|
||||
return sum
|
||||
}
|
||||
|
||||
// String implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) String() string {
|
||||
return lb.Name()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() []Server {
|
||||
avail := make([]Server, 0, lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, srv Server) {
|
||||
if srv.Status().Good() {
|
||||
avail = append(avail, srv)
|
||||
}
|
||||
})
|
||||
return avail
|
||||
}
|
||||
44
internal/net/gphttp/loadbalancer/loadbalancer_test.go
Normal file
44
internal/net/gphttp/loadbalancer/loadbalancer_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
for range 10 {
|
||||
lb.AddServer(types.TestNewServer(0))
|
||||
}
|
||||
lb.rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(new(types.Config))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .4))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
|
||||
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
}
|
||||
22
internal/net/gphttp/loadbalancer/round_robin.go
Normal file
22
internal/net/gphttp/loadbalancer/round_robin.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type roundRobin struct {
|
||||
index atomic.Uint32
|
||||
}
|
||||
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1) % uint32(len(srvs))
|
||||
srvs[index].ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
lb.index.Store(0)
|
||||
}
|
||||
}
|
||||
14
internal/net/gphttp/loadbalancer/types.go
Normal file
14
internal/net/gphttp/loadbalancer/types.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Server = types.Server
|
||||
Servers = []types.Server
|
||||
Pool = types.Pool
|
||||
Weight = types.Weight
|
||||
Config = types.Config
|
||||
Mode = types.Mode
|
||||
)
|
||||
8
internal/net/gphttp/loadbalancer/types/config.go
Normal file
8
internal/net/gphttp/loadbalancer/types/config.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package types
|
||||
|
||||
type Config struct {
|
||||
Link string `json:"link"`
|
||||
Mode Mode `json:"mode"`
|
||||
Weight Weight `json:"weight"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
}
|
||||
32
internal/net/gphttp/loadbalancer/types/mode.go
Normal file
32
internal/net/gphttp/loadbalancer/types/mode.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeUnset Mode = ""
|
||||
ModeRoundRobin Mode = "roundrobin"
|
||||
ModeLeastConn Mode = "leastconn"
|
||||
ModeIPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch strutils.ToLowerNoSnake(string(*mode)) {
|
||||
case "":
|
||||
return true
|
||||
case string(ModeRoundRobin):
|
||||
*mode = ModeRoundRobin
|
||||
return true
|
||||
case string(ModeLeastConn):
|
||||
*mode = ModeLeastConn
|
||||
return true
|
||||
case string(ModeIPHash):
|
||||
*mode = ModeIPHash
|
||||
return true
|
||||
}
|
||||
*mode = ModeRoundRobin
|
||||
return false
|
||||
}
|
||||
86
internal/net/gphttp/loadbalancer/types/server.go
Normal file
86
internal/net/gphttp/loadbalancer/types/server.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type (
|
||||
server struct {
|
||||
_ U.NoCopy
|
||||
|
||||
name string
|
||||
url *net.URL
|
||||
weight Weight
|
||||
|
||||
http.Handler `json:"-"`
|
||||
health.HealthMonitor
|
||||
}
|
||||
|
||||
Server interface {
|
||||
http.Handler
|
||||
health.HealthMonitor
|
||||
Name() string
|
||||
URL() *net.URL
|
||||
Weight() Weight
|
||||
SetWeight(weight Weight)
|
||||
TryWake() error
|
||||
}
|
||||
|
||||
Pool = F.Map[string, Server]
|
||||
)
|
||||
|
||||
var NewServerPool = F.NewMap[Pool]
|
||||
|
||||
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
|
||||
srv := &server{
|
||||
name: name,
|
||||
url: url,
|
||||
weight: weight,
|
||||
Handler: handler,
|
||||
HealthMonitor: healthMon,
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
|
||||
srv := &server{
|
||||
weight: Weight(weight),
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *server) Name() string {
|
||||
return srv.name
|
||||
}
|
||||
|
||||
func (srv *server) URL() *net.URL {
|
||||
return srv.url
|
||||
}
|
||||
|
||||
func (srv *server) Weight() Weight {
|
||||
return srv.weight
|
||||
}
|
||||
|
||||
func (srv *server) SetWeight(weight Weight) {
|
||||
srv.weight = weight
|
||||
}
|
||||
|
||||
func (srv *server) String() string {
|
||||
return srv.name
|
||||
}
|
||||
|
||||
func (srv *server) TryWake() error {
|
||||
waker, ok := srv.Handler.(idlewatcher.Waker)
|
||||
if ok {
|
||||
if err := waker.Wake(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
3
internal/net/gphttp/loadbalancer/types/weight.go
Normal file
3
internal/net/gphttp/loadbalancer/types/weight.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package types
|
||||
|
||||
type Weight uint16
|
||||
20
internal/net/gphttp/logging.go
Normal file
20
internal/net/gphttp/logging.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||
return logging.WithLevel(level).
|
||||
Str("remote", r.RemoteAddr).
|
||||
Str("host", r.Host).
|
||||
Str("uri", r.Method+" "+r.RequestURI)
|
||||
}
|
||||
|
||||
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
||||
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
|
||||
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }
|
||||
func LogDebug(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.DebugLevel) }
|
||||
20
internal/net/gphttp/methods.go
Normal file
20
internal/net/gphttp/methods.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
||||
func IsMethodValid(method string) bool {
|
||||
switch method {
|
||||
case http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodPatch,
|
||||
http.MethodDelete,
|
||||
http.MethodConnect,
|
||||
http.MethodOptions,
|
||||
http.MethodTrace:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
82
internal/net/gphttp/middleware/cidr_whitelist.go
Normal file
82
internal/net/gphttp/middleware/cidr_whitelist.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
cidrWhitelist struct {
|
||||
CIDRWhitelistOpts
|
||||
Tracer
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
CIDRWhitelistOpts struct {
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
|
||||
Message string
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
|
||||
cidrWhitelistDefaults = CIDRWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
||||
statusCode := fl.Field().Int()
|
||||
return gphttp.IsStatusCodeValid(int(statusCode))
|
||||
})
|
||||
}
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (wl *cidrWhitelist) setup() {
|
||||
wl.CIDRWhitelistOpts = cidrWhitelistDefaults
|
||||
wl.cachedAddr = F.NewMapOf[string, bool]()
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (wl *cidrWhitelist) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return wl.checkIP(w, r)
|
||||
}
|
||||
|
||||
// checkIP checks if the IP address is allowed.
|
||||
func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||
var allow, ok bool
|
||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
wl.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
http.Error(w, wl.Message, wl.StatusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
91
internal/net/gphttp/middleware/cidr_whitelist_test.go
Normal file
91
internal/net/gphttp/middleware/cidr_whitelist_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/cidr_whitelist_test.yml
|
||||
var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||
const testMessage = "test-message"
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
_, err = CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
"status": 403,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
_, err = CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"message": testMessage,
|
||||
"status_code": 403,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
t.Run("missing allow", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid cidr", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/123"},
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectErrorT[*net.ParseError](t, err)
|
||||
})
|
||||
t.Run("invalid status code", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"allow": []string{"192.168.2.100/32"},
|
||||
"status_code": 600,
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
errs := gperr.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
panic("bug occurred")
|
||||
}
|
||||
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
|
||||
ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
127
internal/net/gphttp/middleware/cloudflare_real_ip.go
Normal file
127
internal/net/gphttp/middleware/cloudflare_real_ip.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type cloudflareRealIP struct {
|
||||
realIP realIP
|
||||
Recursive bool
|
||||
}
|
||||
|
||||
const (
|
||||
cfIPv4CIDRsEndpoint = "https://www.cloudflare.com/ips-v4"
|
||||
cfIPv6CIDRsEndpoint = "https://www.cloudflare.com/ips-v6"
|
||||
cfCIDRsUpdateInterval = time.Hour
|
||||
cfCIDRsUpdateRetryInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
cfCIDRsLastUpdate time.Time
|
||||
cfCIDRsMu sync.Mutex
|
||||
)
|
||||
|
||||
var CloudflareRealIP = NewMiddleware[cloudflareRealIP]()
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (cri *cloudflareRealIP) setup() {
|
||||
cri.realIP.RealIPOpts = RealIPOpts{
|
||||
Header: "Cf-Connecting-Ip",
|
||||
Recursive: cri.Recursive,
|
||||
}
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (cri *cloudflareRealIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.realIP.From = cidrs
|
||||
}
|
||||
return cri.realIP.before(w, r)
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) enableTrace() {
|
||||
cri.realIP.enableTrace()
|
||||
}
|
||||
|
||||
func (cri *cloudflareRealIP) getTracer() *Tracer {
|
||||
return cri.realIP.getTracer()
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
cfCIDRsMu.Lock()
|
||||
defer cfCIDRsMu.Unlock()
|
||||
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
||||
if common.IsTest {
|
||||
cfCIDRs = []*types.CIDR{
|
||||
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
|
||||
}
|
||||
} else {
|
||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||
err := errors.Join(
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
|
||||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
|
||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
if len(cfCIDRs) == 0 {
|
||||
logging.Warn().Msg("cloudflare CIDR range is empty")
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate = time.Now()
|
||||
logging.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range strutils.SplitLine(string(body)) {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
_, cidr, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
}
|
||||
|
||||
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
80
internal/net/gphttp/middleware/custom_error_page.go
Normal file
80
internal/net/gphttp/middleware/custom_error_page.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||
)
|
||||
|
||||
type customErrorPage struct{}
|
||||
|
||||
var CustomErrorPage = NewMiddleware[customErrorPage]()
|
||||
|
||||
const StaticFilePathPrefix = "/$gperrorpage/"
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
return !ServeStaticErrorPageFile(w, r)
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||
// only handles non-success status code and html/plain content type
|
||||
contentType := gphttp.GetContentType(resp.Header)
|
||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
resp.ContentLength = int64(len(errorPage))
|
||||
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
} else {
|
||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) {
|
||||
path := r.URL.Path
|
||||
if path != "" && path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
if strings.HasPrefix(path, StaticFilePathPrefix) {
|
||||
filename := path[len(StaticFilePathPrefix):]
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
logging.Error().Msg("unable to load resource " + filename)
|
||||
return false
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".html":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
case ".js":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
|
||||
case ".css":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||
default:
|
||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
logging.Err(err).Msg("unable to write resource " + filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
96
internal/net/gphttp/middleware/errorpage/error_page.go
Normal file
96
internal/net/gphttp/middleware/errorpage/error_page.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package errorpage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
const errPagesBasePath = common.ErrorPagesBasePath
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
dirWatcher W.Watcher
|
||||
fileContentMap = F.NewMapOf[string, []byte]()
|
||||
)
|
||||
|
||||
func setup() {
|
||||
t := task.RootTask("error_page", false)
|
||||
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
|
||||
loadContent()
|
||||
go watchDir()
|
||||
}
|
||||
|
||||
func GetStaticFile(filename string) ([]byte, bool) {
|
||||
setupOnce.Do(setup)
|
||||
return fileContentMap.Load(filename)
|
||||
}
|
||||
|
||||
// try <statusCode>.html -> 404.html -> not ok.
|
||||
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||
content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
|
||||
if !ok && statusCode != 404 {
|
||||
return fileContentMap.Load("404.html")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func loadContent() {
|
||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list error page resources")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
if fileContentMap.Has(file) {
|
||||
continue
|
||||
}
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
continue
|
||||
}
|
||||
file = path.Base(file)
|
||||
logging.Info().Msgf("error page resource %s loaded", file)
|
||||
fileContentMap.Store(file, content)
|
||||
}
|
||||
}
|
||||
|
||||
func watchDir() {
|
||||
eventCh, errCh := dirWatcher.Events(task.RootContext())
|
||||
for {
|
||||
select {
|
||||
case <-task.RootContextCanceled():
|
||||
return
|
||||
case event, ok := <-eventCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
filename := event.ActorName
|
||||
switch event.Action {
|
||||
case events.ActionFileWritten:
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
case err := <-errCh:
|
||||
gperr.LogError("error watching error page directory", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package metricslogger
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/metrics"
|
||||
)
|
||||
|
||||
type MetricsLogger struct {
|
||||
ServiceName string `json:"service_name"`
|
||||
}
|
||||
|
||||
func NewMetricsLogger(serviceName string) *MetricsLogger {
|
||||
return &MetricsLogger{serviceName}
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) GetHandler(next http.Handler) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
m.ServeHTTP(rw, req, next.ServeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
visitorIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
visitorIP = req.RemoteAddr
|
||||
}
|
||||
|
||||
// req.RemoteAddr had been modified by middleware (if any)
|
||||
lbls := &metrics.HTTPRouteMetricLabels{
|
||||
Service: m.ServiceName,
|
||||
Method: req.Method,
|
||||
Host: req.Host,
|
||||
Visitor: visitorIP,
|
||||
Path: req.URL.Path,
|
||||
}
|
||||
|
||||
next.ServeHTTP(newHTTPMetricLogger(rw, lbls), req)
|
||||
}
|
||||
|
||||
func (m *MetricsLogger) ResetMetrics() {
|
||||
metrics.GetRouteMetrics().UnregisterService(m.ServiceName)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package metricslogger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/metrics"
|
||||
)
|
||||
|
||||
type httpMetricLogger struct {
|
||||
http.ResponseWriter
|
||||
timestamp time.Time
|
||||
labels *metrics.HTTPRouteMetricLabels
|
||||
}
|
||||
|
||||
// WriteHeader implements http.ResponseWriter.
|
||||
func (l *httpMetricLogger) WriteHeader(status int) {
|
||||
l.ResponseWriter.WriteHeader(status)
|
||||
duration := time.Since(l.timestamp)
|
||||
go func() {
|
||||
m := metrics.GetRouteMetrics()
|
||||
m.HTTPReqTotal.Inc()
|
||||
m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds()))
|
||||
|
||||
// ignore 1xx
|
||||
switch {
|
||||
case status >= 500:
|
||||
m.HTTP5xx.With(l.labels).Inc()
|
||||
case status >= 400:
|
||||
m.HTTP4xx.With(l.labels).Inc()
|
||||
case status >= 200:
|
||||
m.HTTP2xx3xx.With(l.labels).Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *httpMetricLogger) Unwrap() http.ResponseWriter {
|
||||
return l.ResponseWriter
|
||||
}
|
||||
|
||||
func newHTTPMetricLogger(w http.ResponseWriter, labels *metrics.HTTPRouteMetricLabels) *httpMetricLogger {
|
||||
return &httpMetricLogger{
|
||||
ResponseWriter: w,
|
||||
timestamp: time.Now(),
|
||||
labels: labels,
|
||||
}
|
||||
}
|
||||
237
internal/net/gphttp/middleware/middleware.go
Normal file
237
internal/net/gphttp/middleware/middleware.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Error = gperr.Error
|
||||
|
||||
ReverseProxy = reverseproxy.ReverseProxy
|
||||
ProxyRequest = reverseproxy.ProxyRequest
|
||||
|
||||
ImplNewFunc = func() any
|
||||
OptionsRaw = map[string]any
|
||||
|
||||
Middleware struct {
|
||||
name string
|
||||
construct ImplNewFunc
|
||||
impl any
|
||||
// priority is only applied for ReverseProxy.
|
||||
//
|
||||
// Middleware compose follows the order of the slice
|
||||
//
|
||||
// Default is 10, 0 is the highest
|
||||
priority int
|
||||
}
|
||||
ByPriority []*Middleware
|
||||
|
||||
RequestModifier interface {
|
||||
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
}
|
||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||
MiddlewareWithSetup interface{ setup() }
|
||||
MiddlewareFinalizer interface{ finalize() }
|
||||
MiddlewareFinalizerWithError interface {
|
||||
finalize() error
|
||||
}
|
||||
MiddlewareWithTracer interface {
|
||||
enableTrace()
|
||||
getTracer() *Tracer
|
||||
}
|
||||
)
|
||||
|
||||
const DefaultPriority = 10
|
||||
|
||||
func (m ByPriority) Len() int { return len(m) }
|
||||
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
|
||||
|
||||
func NewMiddleware[ImplType any]() *Middleware {
|
||||
// type check
|
||||
t := any(new(ImplType))
|
||||
switch t.(type) {
|
||||
case RequestModifier:
|
||||
case ResponseModifier:
|
||||
default:
|
||||
panic("must implement RequestModifier or ResponseModifier")
|
||||
}
|
||||
_, hasFinializer := t.(MiddlewareFinalizer)
|
||||
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
||||
if hasFinializer && hasFinializerWithError {
|
||||
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
||||
}
|
||||
return &Middleware{
|
||||
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||
construct: func() any { return new(ImplType) },
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.enableTrace()
|
||||
logging.Debug().Msgf("middleware %s enabled trace", m.name)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getTracer() *Tracer {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
return tracer.getTracer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) setParent(parent *Middleware) {
|
||||
if tracer := m.getTracer(); tracer != nil {
|
||||
tracer.SetParent(parent.getTracer())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) setup() {
|
||||
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
||||
setup.setup()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||
if len(optsRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
priority, ok := optsRaw["priority"].(int)
|
||||
if ok {
|
||||
m.priority = priority
|
||||
// remove priority for deserialization, restore later
|
||||
delete(optsRaw, "priority")
|
||||
defer func() {
|
||||
optsRaw["priority"] = priority
|
||||
}()
|
||||
} else {
|
||||
m.priority = DefaultPriority
|
||||
}
|
||||
return utils.Deserialize(optsRaw, m.impl)
|
||||
}
|
||||
|
||||
func (m *Middleware) finalize() error {
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||
finalizer.finalize()
|
||||
return nil
|
||||
}
|
||||
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
||||
return finalizer.finalize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, gperr.Error) {
|
||||
if m.construct == nil { // likely a middleware from compose
|
||||
if len(optsRaw) != 0 {
|
||||
return nil, gperr.New("additional options not allowed for middleware ").Subject(m.name)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
mid := &Middleware{name: m.name, impl: m.construct()}
|
||||
mid.setup()
|
||||
if err := mid.apply(optsRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := mid.finalize(); err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
return mid, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) String() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||
return json.MarshalIndent(map[string]any{
|
||||
"name": m.name,
|
||||
"options": m.impl,
|
||||
"priority": m.priority,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
return exec.modifyResponse(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
return exec.modifyResponse(resp)
|
||||
})
|
||||
}
|
||||
if exec, ok := m.impl.(RequestModifier); ok {
|
||||
if proceed := exec.before(w, r); !proceed {
|
||||
return
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err gperr.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
patchReverseProxy(rp, middlewares)
|
||||
return
|
||||
}
|
||||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
sort.Sort(ByPriority(middlewares))
|
||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||
|
||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||
|
||||
if before, ok := mid.impl.(RequestModifier); ok {
|
||||
next := rp.HandlerFunc
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
if proceed := before.before(w, r); proceed {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mr, ok := mid.impl.(ResponseModifier); ok {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
if err := mr.modifyResponse(res); err != nil {
|
||||
return err
|
||||
}
|
||||
return ori(res)
|
||||
}
|
||||
} else {
|
||||
rp.ModifyResponse = mr.modifyResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
107
internal/net/gphttp/middleware/middleware_builder.go
Normal file
107
internal/net/gphttp/middleware/middleware_builder.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field")
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware {
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
middlewares := make(map[string]*Middleware)
|
||||
for name, defs := range rawMap {
|
||||
chain, err := BuildMiddlewareFromChainRaw(name, defs)
|
||||
if err != nil {
|
||||
eb.Add(err.Subject(source))
|
||||
} else {
|
||||
middlewares[name+"@file"] = chain
|
||||
}
|
||||
}
|
||||
return middlewares
|
||||
}
|
||||
|
||||
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
errs := gperr.NewBuilder("middlewares compile error")
|
||||
invalidOpts := gperr.NewBuilder("options compile error")
|
||||
|
||||
for name, opts := range middlewaresMap {
|
||||
m, err := Get(name)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err = m.New(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
}
|
||||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
if invalidOpts.HasError() {
|
||||
errs.Add(invalidOpts.Error())
|
||||
}
|
||||
sort.Sort(ByPriority(middlewares))
|
||||
return middlewares, errs.Error()
|
||||
}
|
||||
|
||||
func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) {
|
||||
compiled, err := compileMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMiddlewareChain(name, compiled), nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) {
|
||||
chainErr := gperr.NewBuilder("")
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(ErrMissingMiddlewareUse.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, err := Get(baseName)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.New(def)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s[%d]", name, i))
|
||||
continue
|
||||
}
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
chain = append(chain, m)
|
||||
}
|
||||
if chainErr.HasError() {
|
||||
return nil, chainErr.Error()
|
||||
}
|
||||
return NewMiddlewareChain(name, chain), nil
|
||||
}
|
||||
22
internal/net/gphttp/middleware/middleware_builder_test.go
Normal file
22
internal/net/gphttp/middleware/middleware_builder_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/middleware_compose.yml
|
||||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
errs := gperr.NewBuilder("")
|
||||
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
|
||||
ExpectNoError(t, errs.Error())
|
||||
Must(json.MarshalIndent(middlewares, "", " "))
|
||||
// t.Log(string(data))
|
||||
// TODO: test
|
||||
}
|
||||
61
internal/net/gphttp/middleware/middleware_chain.go
Normal file
61
internal/net/gphttp/middleware/middleware_chain.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type middlewareChain struct {
|
||||
befores []RequestModifier
|
||||
modResps []ResponseModifier
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func NewMiddlewareChain(name string, chain []*Middleware) *Middleware {
|
||||
chainMid := &middlewareChain{befores: []RequestModifier{}, modResps: []ResponseModifier{}}
|
||||
m := &Middleware{name: name, impl: chainMid}
|
||||
|
||||
for _, comp := range chain {
|
||||
if before, ok := comp.impl.(RequestModifier); ok {
|
||||
chainMid.befores = append(chainMid.befores, before)
|
||||
}
|
||||
if mr, ok := comp.impl.(ResponseModifier); ok {
|
||||
chainMid.modResps = append(chainMid.modResps, mr)
|
||||
}
|
||||
comp.setParent(m)
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
for _, child := range chain {
|
||||
child.enableTrace()
|
||||
}
|
||||
m.enableTrace()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||
for _, b := range m.befores {
|
||||
if proceedNext = b.before(w, r); !proceedNext {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
errs := gperr.NewBuilder("modify response errors")
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
errs.Add(gperr.Wrap(err).Subjectf("%d", i))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
37
internal/net/gphttp/middleware/middleware_test.go
Normal file
37
internal/net/gphttp/middleware/middleware_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
type testPriority struct {
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
var test = NewMiddleware[testPriority]()
|
||||
|
||||
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMiddlewarePriority(t *testing.T) {
|
||||
priorities := []int{4, 7, 9, 0}
|
||||
chain := make([]*Middleware, len(priorities))
|
||||
for i, p := range priorities {
|
||||
mid, err := test.New(OptionsRaw{
|
||||
"priority": p,
|
||||
"value": i,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
chain[i] = mid
|
||||
}
|
||||
res, err := newMiddlewaresTest(chain, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
|
||||
}
|
||||
104
internal/net/gphttp/middleware/middlewares.go
Normal file
104
internal/net/gphttp/middleware/middlewares.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"path"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
// snakes and cases will be stripped on `Get`
|
||||
// so keys are lowercase without snake.
|
||||
var allMiddlewares = map[string]*Middleware{
|
||||
"redirecthttp": RedirectHTTP,
|
||||
|
||||
"oidc": OIDC,
|
||||
|
||||
"request": ModifyRequest,
|
||||
"modifyrequest": ModifyRequest,
|
||||
"response": ModifyResponse,
|
||||
"modifyresponse": ModifyResponse,
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
|
||||
"errorpage": CustomErrorPage,
|
||||
"customerrorpage": CustomErrorPage,
|
||||
|
||||
"realip": RealIP,
|
||||
"cloudflarerealip": CloudflareRealIP,
|
||||
|
||||
"cidrwhitelist": CIDRWhiteList,
|
||||
"ratelimit": RateLimiter,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnknownMiddleware = gperr.New("unknown middleware")
|
||||
ErrDuplicatedMiddleware = gperr.New("duplicated middleware")
|
||||
)
|
||||
|
||||
func Get(name string) (*Middleware, Error) {
|
||||
middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
|
||||
if !ok {
|
||||
return nil, ErrUnknownMiddleware.
|
||||
Subject(name).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
|
||||
}
|
||||
return middleware, nil
|
||||
}
|
||||
|
||||
func All() map[string]*Middleware {
|
||||
return allMiddlewares
|
||||
}
|
||||
|
||||
func LoadComposeFiles() {
|
||||
errs := gperr.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
voidErrs := gperr.NewBuilder("") // ignore these errors, will be added in next step
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
name = strutils.ToLowerNoSnake(name)
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
errs.Add(ErrDuplicatedMiddleware.Subject(name))
|
||||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
}
|
||||
// build again to resolve cross references
|
||||
for _, defFile := range middlewareDefs {
|
||||
mws := BuildMiddlewaresFromComposeFile(defFile, errs)
|
||||
if len(mws) == 0 {
|
||||
continue
|
||||
}
|
||||
for name, m := range mws {
|
||||
name = strutils.ToLowerNoSnake(name)
|
||||
if _, ok := allMiddlewares[name]; ok {
|
||||
// already loaded above
|
||||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
}
|
||||
}
|
||||
if errs.HasError() {
|
||||
gperr.LogError(errs.About(), errs.Error())
|
||||
}
|
||||
}
|
||||
91
internal/net/gphttp/middleware/modify_request.go
Normal file
91
internal/net/gphttp/middleware/modify_request.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyRequest struct {
|
||||
ModifyRequestOpts
|
||||
Tracer
|
||||
}
|
||||
// order: add_prefix -> set_headers -> add_headers -> hide_headers
|
||||
ModifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
AddPrefix string
|
||||
|
||||
needVarSubstitution bool
|
||||
}
|
||||
)
|
||||
|
||||
var ModifyRequest = NewMiddleware[modifyRequest]()
|
||||
|
||||
// finalize implements MiddlewareFinalizer.
|
||||
func (mr *ModifyRequestOpts) finalize() {
|
||||
mr.checkVarSubstitution()
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (mr *modifyRequest) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
mr.AddTraceRequest("before modify request", r)
|
||||
|
||||
mr.addPrefix(r, nil, r.URL.Path)
|
||||
mr.modifyHeaders(r, nil, r.Header)
|
||||
mr.AddTraceRequest("after modify request", r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mr *ModifyRequestOpts) checkVarSubstitution() {
|
||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||
for _, v := range m {
|
||||
if strings.ContainsRune(v, '$') {
|
||||
mr.needVarSubstitution = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *ModifyRequestOpts) modifyHeaders(req *http.Request, resp *http.Response, headers http.Header) {
|
||||
if !mr.needVarSubstitution {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = v
|
||||
}()
|
||||
}
|
||||
headers[k] = []string{v}
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
headers[k] = append(headers[k], v)
|
||||
}
|
||||
} else {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if req != nil && strings.EqualFold(k, "host") {
|
||||
defer func() {
|
||||
req.Host = varReplace(req, resp, v)
|
||||
}()
|
||||
}
|
||||
headers[k] = []string{varReplace(req, resp, v)}
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
headers[k] = append(headers[k], varReplace(req, resp, v))
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range mr.HideHeaders {
|
||||
delete(headers, k)
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) addPrefix(r *http.Request, _ *http.Response, path string) {
|
||||
if len(mr.AddPrefix) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.URL.Path = filepath.Join(mr.AddPrefix, path)
|
||||
}
|
||||
145
internal/net/gphttp/middleware/modify_request_test.go
Normal file
145
internal/net/gphttp/middleware/modify_request_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestModifyRequest(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": VarUpstreamAddr,
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||
"X-Test-Remote-Host": VarRemoteHost,
|
||||
"X-Test-Remote-Port": VarRemotePort,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
|
||||
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
|
||||
|
||||
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
|
||||
t.Run("add_prefix", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectedPath string
|
||||
upstreamURL string
|
||||
addPrefix string
|
||||
}{
|
||||
{
|
||||
name: "no prefix",
|
||||
path: "/foo",
|
||||
expectedPath: "/foo",
|
||||
upstreamURL: "http://test.example.com",
|
||||
},
|
||||
{
|
||||
name: "slash only",
|
||||
path: "/",
|
||||
expectedPath: "/",
|
||||
upstreamURL: "http://test.example.com",
|
||||
addPrefix: "/", // should not change anything
|
||||
},
|
||||
{
|
||||
name: "some prefix",
|
||||
path: "/test",
|
||||
expectedPath: "/foo/test",
|
||||
upstreamURL: "http://test.example.com",
|
||||
addPrefix: "/foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app" + tt.path)
|
||||
upstreamURL := types.MustParseURL(tt.upstreamURL)
|
||||
|
||||
opts["add_prefix"] = tt.addPrefix
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
20
internal/net/gphttp/middleware/modify_response.go
Normal file
20
internal/net/gphttp/middleware/modify_response.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type modifyResponse struct {
|
||||
ModifyRequestOpts
|
||||
Tracer
|
||||
}
|
||||
|
||||
var ModifyResponse = NewMiddleware[modifyResponse]()
|
||||
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
mr.AddTraceResponse("before modify response", resp)
|
||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||
mr.AddTraceResponse("after modify response", resp)
|
||||
return nil
|
||||
}
|
||||
108
internal/net/gphttp/middleware/modify_response_test.go
Normal file
108
internal/net/gphttp/middleware/modify_response_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestModifyResponse(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"X-Test-Resp-Status": VarRespStatusCode,
|
||||
"X-Test-Resp-Content-Type": VarRespContentType,
|
||||
"X-Test-Resp-Content-Length": VarRespContentLen,
|
||||
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
|
||||
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||
"X-Test-Remote-Host": VarRemoteHost,
|
||||
"X-Test-Remote-Port": VarRemotePort,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("response_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respBody: bytes.Repeat([]byte("a"), 50),
|
||||
respStatus: http.StatusOK,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
|
||||
|
||||
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
}
|
||||
91
internal/net/gphttp/middleware/oidc.go
Normal file
91
internal/net/gphttp/middleware/oidc.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
AllowedUsers []string `json:"allowed_users"`
|
||||
AllowedGroups []string `json:"allowed_groups"`
|
||||
|
||||
auth auth.Provider
|
||||
authMux *http.ServeMux
|
||||
|
||||
isInitialized int32
|
||||
initMu sync.Mutex
|
||||
}
|
||||
|
||||
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return gperr.New("OIDC not enabled but OIDC middleware is used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) init() error {
|
||||
if atomic.LoadInt32(&amw.isInitialized) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return amw.initSlow()
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) initSlow() error {
|
||||
amw.initMu.Lock()
|
||||
if amw.isInitialized == 1 {
|
||||
amw.initMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
amw.isInitialized = 1
|
||||
amw.initMu.Unlock()
|
||||
}()
|
||||
|
||||
authProvider, err := auth.NewOIDCProviderFromEnv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authProvider.SetIsMiddleware(true)
|
||||
if len(amw.AllowedUsers) > 0 {
|
||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||
}
|
||||
if len(amw.AllowedGroups) > 0 {
|
||||
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||
}
|
||||
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||
amw.auth = authProvider
|
||||
return nil
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if err := amw.init(); err != nil {
|
||||
// no need to log here, main OIDC may already failed and logged
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := amw.auth.CheckToken(r); err != nil {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
return false
|
||||
}
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.auth.LogoutCallbackHandler(w, r)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
73
internal/net/gphttp/middleware/rate_limit.go
Normal file
73
internal/net/gphttp/middleware/rate_limit.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
requestMap = map[string]*rate.Limiter
|
||||
rateLimiter struct {
|
||||
RateLimiterOpts
|
||||
Tracer
|
||||
|
||||
requestMap requestMap
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
RateLimiterOpts struct {
|
||||
Average int `validate:"min=1,required"`
|
||||
Burst int `validate:"min=1,required"`
|
||||
Period time.Duration `validate:"min=1s"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RateLimiter = NewMiddleware[rateLimiter]()
|
||||
rateLimiterOptsDefault = RateLimiterOpts{
|
||||
Period: time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (rl *rateLimiter) setup() {
|
||||
rl.RateLimiterOpts = rateLimiterOptsDefault
|
||||
rl.requestMap = make(requestMap, 0)
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
return rl.limit(w, r)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) newLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
|
||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
limiter, ok := rl.requestMap[host]
|
||||
if !ok {
|
||||
limiter = rl.newLimiter()
|
||||
rl.requestMap[host] = limiter
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
if limiter.Allow() {
|
||||
return true
|
||||
}
|
||||
|
||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||
return false
|
||||
}
|
||||
27
internal/net/gphttp/middleware/rate_limit_test.go
Normal file
27
internal/net/gphttp/middleware/rate_limit_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"average": "10",
|
||||
"burst": "10",
|
||||
"period": "1s",
|
||||
}
|
||||
|
||||
rl, err := RateLimiter.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(rl, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
result, err := newMiddlewareTest(rl, nil)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
|
||||
}
|
||||
116
internal/net/gphttp/middleware/real_ip.go
Normal file
116
internal/net/gphttp/middleware/real_ip.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
||||
type (
|
||||
realIP struct {
|
||||
RealIPOpts
|
||||
Tracer
|
||||
}
|
||||
RealIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `validate:"required,min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last address sent in the request header field defined by the Header field.
|
||||
If recursive search is enabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
RealIP = NewMiddleware[realIP]()
|
||||
realIPOptsDefault = RealIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*types.CIDR{},
|
||||
}
|
||||
)
|
||||
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (ri *realIP) setup() {
|
||||
ri.RealIPOpts = realIPOptsDefault
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (ri *realIP) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
ri.setRealIP(r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||
for _, CIDR := range ri.From {
|
||||
if CIDR.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// not in any CIDR
|
||||
return false
|
||||
}
|
||||
|
||||
func (ri *realIP) setRealIP(req *http.Request) {
|
||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
clientIPStr = req.RemoteAddr
|
||||
}
|
||||
|
||||
clientIP := net.ParseIP(clientIPStr)
|
||||
isTrusted := false
|
||||
|
||||
for _, CIDR := range ri.From {
|
||||
if CIDR.Contains(clientIP) {
|
||||
isTrusted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isTrusted {
|
||||
ri.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
return
|
||||
}
|
||||
|
||||
realIPs := req.Header.Values(ri.Header)
|
||||
lastNonTrustedIP := ""
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
// try non-canonical key
|
||||
realIPs = req.Header[ri.Header]
|
||||
}
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
ri.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
if !ri.Recursive {
|
||||
lastNonTrustedIP = realIPs[len(realIPs)-1]
|
||||
} else {
|
||||
for _, r := range realIPs {
|
||||
if !ri.isInCIDRList(net.ParseIP(r)) {
|
||||
lastNonTrustedIP = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastNonTrustedIP == "" {
|
||||
ri.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
return
|
||||
}
|
||||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
|
||||
ri.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
||||
77
internal/net/gphttp/middleware/real_ip_test.go
Normal file
77
internal/net/gphttp/middleware/real_ip_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIPOpts(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": httpheaders.HeaderXRealIP,
|
||||
"from": []string{
|
||||
"127.0.0.0/8",
|
||||
"192.168.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
},
|
||||
"recursive": true,
|
||||
}
|
||||
optExpected := &RealIPOpts{
|
||||
Header: httpheaders.HeaderXRealIP,
|
||||
From: []*types.CIDR{
|
||||
{
|
||||
IP: net.ParseIP("127.0.0.0"),
|
||||
Mask: net.IPv4Mask(255, 0, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("192.168.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("172.16.0.0"),
|
||||
Mask: net.IPv4Mask(255, 240, 0, 0),
|
||||
},
|
||||
},
|
||||
Recursive: true,
|
||||
}
|
||||
|
||||
ri, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
for i, CIDR := range ri.impl.(*realIP).From {
|
||||
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
const (
|
||||
testHeader = httpheaders.HeaderXRealIP
|
||||
testRealIP = "192.168.1.1"
|
||||
)
|
||||
opts := OptionsRaw{
|
||||
"header": testHeader,
|
||||
"from": []string{"0.0.0.0/0"},
|
||||
}
|
||||
optsMr := OptionsRaw{
|
||||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := RealIP.New(opts)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mr, err := ModifyRequest.New(optsMr)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err)
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
}
|
||||
29
internal/net/gphttp/middleware/redirect_http.go
Normal file
29
internal/net/gphttp/middleware/redirect_http.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type redirectHTTP struct{}
|
||||
|
||||
var RedirectHTTP = NewMiddleware[redirectHTTP]()
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (redirectHTTP) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
r.URL.Scheme = "https"
|
||||
host := r.Host
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i] // strip port number if present
|
||||
}
|
||||
r.URL.Host = host + ":" + common.ProxyHTTPSPort
|
||||
logging.Debug().Str("url", r.URL.String()).Msg("redirect to https")
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return true
|
||||
}
|
||||
27
internal/net/gphttp/middleware/redirect_http_test.go
Normal file
27
internal/net/gphttp/middleware/redirect_http_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPs(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("http://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
|
||||
}
|
||||
|
||||
func TestNoRedirect(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
reqURL: types.MustParseURL("https://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
37
internal/net/gphttp/middleware/set_upstream_headers.go
Normal file
37
internal/net/gphttp/middleware/set_upstream_headers.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
)
|
||||
|
||||
// internal use only.
|
||||
type setUpstreamHeaders struct {
|
||||
Name, Scheme, Host, Port string
|
||||
}
|
||||
|
||||
var suh = NewMiddleware[setUpstreamHeaders]()
|
||||
|
||||
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
|
||||
m, err := suh.New(OptionsRaw{
|
||||
"name": rp.TargetName,
|
||||
"scheme": rp.TargetURL.Scheme,
|
||||
"host": rp.TargetURL.Hostname(),
|
||||
"port": rp.TargetURL.Port(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
deny:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.1.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
accept:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.0.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
- 127.0.0.1
|
||||
@@ -0,0 +1,41 @@
|
||||
theGreatPretender:
|
||||
- use: HideXForwarded
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 6.6.6.6
|
||||
- use: ModifyResponse
|
||||
hideHeaders:
|
||||
- X-Test3
|
||||
- X-Test4
|
||||
|
||||
notAuthenticAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
addAuthCookiesToResponse:
|
||||
- session_id
|
||||
- user_id
|
||||
authResponseHeaders:
|
||||
- X-Auth-SessionID
|
||||
- X-Auth-UserID
|
||||
- use: CustomErrorPage
|
||||
|
||||
realIPAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- "127.0.0.0/8"
|
||||
- "192.168.0.0/16"
|
||||
- "172.16.0.0/12"
|
||||
recursive: true
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
|
||||
testFakeRealIP:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
CF-Connecting-IP: 127.0.0.1
|
||||
- use: CloudflareRealIP
|
||||
17
internal/net/gphttp/middleware/test_data/sample_headers.json
Normal file
17
internal/net/gphttp/middleware/test_data/sample_headers.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
|
||||
"Dnt": "1",
|
||||
"Host": "localhost",
|
||||
"Priority": "u=0, i",
|
||||
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
|
||||
"Sec-Ch-Ua-Mobile": "?0",
|
||||
"Sec-Ch-Ua-Platform": "\"Windows\"",
|
||||
"Sec-Fetch-Dest": "document",
|
||||
"Sec-Fetch-Mode": "navigate",
|
||||
"Sec-Fetch-Site": "none",
|
||||
"Sec-Fetch-User": "?1",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
|
||||
}
|
||||
176
internal/net/gphttp/middleware/test_utils.go
Normal file
176
internal/net/gphttp/middleware/test_utils.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/sample_headers.json
|
||||
var testHeadersRaw []byte
|
||||
var testHeaders http.Header
|
||||
|
||||
func init() {
|
||||
if !common.IsTest {
|
||||
return
|
||||
}
|
||||
tmp := map[string]string{}
|
||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
testHeaders = http.Header{}
|
||||
for k, v := range tmp {
|
||||
testHeaders.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
type requestRecorder struct {
|
||||
args *testArgs
|
||||
|
||||
parent http.RoundTripper
|
||||
headers http.Header
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func newRequestRecorder(args *testArgs) *requestRecorder {
|
||||
return &requestRecorder{args: args}
|
||||
}
|
||||
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
rt.headers = req.Header
|
||||
rt.remoteAddr = req.RemoteAddr
|
||||
if rt.parent != nil {
|
||||
resp, err = rt.parent.RoundTrip(req)
|
||||
} else {
|
||||
resp = &http.Response{
|
||||
Status: http.StatusText(rt.args.respStatus),
|
||||
StatusCode: rt.args.respStatus,
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
|
||||
ContentLength: int64(len(rt.args.respBody)),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
for k, v := range rt.args.respHeaders {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
ResponseStatus int
|
||||
RemoteAddr string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
upstreamURL *types.URL
|
||||
|
||||
realRoundTrip bool
|
||||
|
||||
reqURL *types.URL
|
||||
reqMethod string
|
||||
headers http.Header
|
||||
body []byte
|
||||
|
||||
respHeaders http.Header
|
||||
respBody []byte
|
||||
respStatus int
|
||||
}
|
||||
|
||||
func (args *testArgs) setDefaults() {
|
||||
if args.reqURL == nil {
|
||||
args.reqURL = Must(types.ParseURL("https://example.com"))
|
||||
}
|
||||
if args.reqMethod == "" {
|
||||
args.reqMethod = http.MethodGet
|
||||
}
|
||||
if args.upstreamURL == nil {
|
||||
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
}
|
||||
if args.respHeaders == nil {
|
||||
args.respHeaders = http.Header{}
|
||||
}
|
||||
if args.respBody == nil {
|
||||
args.respBody = []byte("OK")
|
||||
}
|
||||
if args.respStatus == 0 {
|
||||
args.respStatus = http.StatusOK
|
||||
}
|
||||
}
|
||||
|
||||
func (args *testArgs) bodyReader() io.Reader {
|
||||
if args.body != nil {
|
||||
return bytes.NewReader(args.body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
args.setDefaults()
|
||||
|
||||
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
|
||||
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||
}
|
||||
|
||||
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, gperr.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
args.setDefaults()
|
||||
|
||||
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||
for k, v := range args.headers {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rr := newRequestRecorder(args)
|
||||
if args.realRoundTrip {
|
||||
rr.parent = http.DefaultTransport
|
||||
}
|
||||
|
||||
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
|
||||
patchReverseProxy(rp, middlewares)
|
||||
rp.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err)
|
||||
}
|
||||
|
||||
return &TestResult{
|
||||
RequestHeaders: rr.headers,
|
||||
ResponseHeaders: resp.Header,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
RemoteAddr: rr.remoteAddr,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
87
internal/net/gphttp/middleware/trace.go
Normal file
87
internal/net/gphttp/middleware/trace.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders map[string]string `json:"req_headers,omitempty"`
|
||||
RespHeaders map[string]string `json:"resp_headers,omitempty"`
|
||||
RespStatus int `json:"resp_status,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
Traces []*Trace
|
||||
)
|
||||
|
||||
var (
|
||||
traces = make(Traces, 0)
|
||||
tracesMu sync.Mutex
|
||||
)
|
||||
|
||||
const MaxTraceNum = 100
|
||||
|
||||
func GetAllTrace() []*Trace {
|
||||
return traces
|
||||
}
|
||||
|
||||
func (tr *Trace) WithRequest(req *http.Request) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = req.RequestURI
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = resp.Request.RequestURI
|
||||
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
|
||||
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
|
||||
tr.RespStatus = resp.StatusCode
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) With(what string, additional any) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional[what] = additional
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithError(err error) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional["error"] = err.Error()
|
||||
return tr
|
||||
}
|
||||
|
||||
func addTrace(t *Trace) *Trace {
|
||||
tracesMu.Lock()
|
||||
defer tracesMu.Unlock()
|
||||
if len(traces) > MaxTraceNum {
|
||||
traces = traces[1:]
|
||||
}
|
||||
traces = append(traces, t)
|
||||
return t
|
||||
}
|
||||
62
internal/net/gphttp/middleware/tracer.go
Normal file
62
internal/net/gphttp/middleware/tracer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
name string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func _() {
|
||||
var _ MiddlewareWithTracer = &Tracer{}
|
||||
}
|
||||
|
||||
func (t *Tracer) enableTrace() {
|
||||
t.enabled = true
|
||||
}
|
||||
|
||||
func (t *Tracer) getTracer() *Tracer {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tracer) SetParent(parent *Tracer) {
|
||||
if parent == nil {
|
||||
return
|
||||
}
|
||||
t.name = parent.name + "." + t.name
|
||||
}
|
||||
|
||||
func (t *Tracer) addTrace(msg string) *Trace {
|
||||
return addTrace(&Trace{
|
||||
Time: strutils.FormatTime(time.Now()),
|
||||
Caller: t.name,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTracef(msg string, args ...any) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceRequest(msg string, req *http.Request) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (t *Tracer) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if !t.enabled {
|
||||
return nil
|
||||
}
|
||||
return t.addTrace(msg).WithResponse(resp)
|
||||
}
|
||||
175
internal/net/gphttp/middleware/vars.go
Normal file
175
internal/net/gphttp/middleware/vars.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
reqVarGetter func(*http.Request) string
|
||||
respVarGetter func(*http.Response) string
|
||||
)
|
||||
|
||||
var (
|
||||
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
|
||||
reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
|
||||
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
|
||||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||
)
|
||||
|
||||
const (
|
||||
VarRequestMethod = "$req_method"
|
||||
VarRequestScheme = "$req_scheme"
|
||||
VarRequestHost = "$req_host"
|
||||
VarRequestPort = "$req_port"
|
||||
VarRequestPath = "$req_path"
|
||||
VarRequestAddr = "$req_addr"
|
||||
VarRequestQuery = "$req_query"
|
||||
VarRequestURL = "$req_url"
|
||||
VarRequestURI = "$req_uri"
|
||||
VarRequestContentType = "$req_content_type"
|
||||
VarRequestContentLen = "$req_content_length"
|
||||
VarRemoteHost = "$remote_host"
|
||||
VarRemotePort = "$remote_port"
|
||||
VarRemoteAddr = "$remote_addr"
|
||||
|
||||
VarUpstreamName = "$upstream_name"
|
||||
VarUpstreamScheme = "$upstream_scheme"
|
||||
VarUpstreamHost = "$upstream_host"
|
||||
VarUpstreamPort = "$upstream_port"
|
||||
VarUpstreamAddr = "$upstream_addr"
|
||||
VarUpstreamURL = "$upstream_url"
|
||||
|
||||
VarRespContentType = "$resp_content_type"
|
||||
VarRespContentLen = "$resp_content_length"
|
||||
VarRespStatusCode = "$status_code"
|
||||
)
|
||||
|
||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
VarRequestMethod: func(req *http.Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *http.Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
},
|
||||
VarRequestHost: func(req *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return req.Host
|
||||
}
|
||||
return reqHost
|
||||
},
|
||||
VarRequestPort: func(req *http.Request) string {
|
||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||
return reqPort
|
||||
},
|
||||
VarRequestAddr: func(req *http.Request) string { return req.Host },
|
||||
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteHost: func(req *http.Request) string {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientIP
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemotePort: func(req *http.Request) string {
|
||||
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientPort
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *http.Request) string {
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
return upHost + ":" + upPort
|
||||
}
|
||||
return upHost
|
||||
},
|
||||
VarUpstreamURL: func(req *http.Request) string {
|
||||
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
}
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
upAddr := upHost
|
||||
if upPort != "" {
|
||||
upAddr += ":" + upPort
|
||||
}
|
||||
return upScheme + "://" + upAddr
|
||||
},
|
||||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||
VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
}
|
||||
|
||||
func varReplace(req *http.Request, resp *http.Response, s string) string {
|
||||
if req != nil {
|
||||
// Replace query parameters
|
||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||
name := match[5 : len(match)-1]
|
||||
for k, v := range req.URL.Query() {
|
||||
if strings.EqualFold(k, name) {
|
||||
return v[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
// Replace request headers
|
||||
s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
|
||||
return req.Header.Get(header)
|
||||
})
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
// Replace response headers
|
||||
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
|
||||
return resp.Header.Get(header)
|
||||
})
|
||||
}
|
||||
|
||||
// Replace static variables
|
||||
if req != nil {
|
||||
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||
if fn, ok := staticReqVarSubsMap[match]; ok {
|
||||
return fn(req)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||
if fn, ok := staticRespVarSubsMap[match]; ok {
|
||||
return fn(resp)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
45
internal/net/gphttp/middleware/x_forwarded.go
Normal file
45
internal/net/gphttp/middleware/x_forwarded.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
)
|
||||
|
||||
type (
|
||||
setXForwarded struct{}
|
||||
hideXForwarded struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
SetXForwarded = NewMiddleware[setXForwarded]()
|
||||
HideXForwarded = NewMiddleware[hideXForwarded]()
|
||||
)
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Del(httpheaders.HeaderXForwardedFor)
|
||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (hideXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
toDelete := make([]string, 0, len(r.Header))
|
||||
for k := range r.Header {
|
||||
if strings.HasPrefix(k, "X-Forwarded-") {
|
||||
toDelete = append(toDelete, k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range toDelete {
|
||||
r.Header.Del(k)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
116
internal/net/gphttp/modify_response_writer.go
Normal file
116
internal/net/gphttp/modify_response_writer.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
||||
// Copyright (c) 2020-2024 Traefik Labs
|
||||
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
ModifyResponseFunc func(*http.Response) error
|
||||
ModifyResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
|
||||
headerSent bool
|
||||
code int
|
||||
size int
|
||||
|
||||
modifier ModifyResponseFunc
|
||||
modified bool
|
||||
modifierErr error
|
||||
}
|
||||
)
|
||||
|
||||
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
|
||||
return &ModifyResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
modifier: f,
|
||||
code: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.w
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) StatusCode() int {
|
||||
return w.code
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Size() int {
|
||||
return w.size
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) WriteHeader(code int) {
|
||||
if w.headerSent {
|
||||
return
|
||||
}
|
||||
|
||||
if code >= http.StatusContinue && code < http.StatusOK {
|
||||
w.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
w.headerSent = true
|
||||
w.code = code
|
||||
}()
|
||||
|
||||
if w.modifier == nil || w.modified {
|
||||
w.w.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
resp := http.Response{
|
||||
StatusCode: code,
|
||||
Header: w.w.Header(),
|
||||
Request: w.r,
|
||||
ContentLength: int64(w.size),
|
||||
}
|
||||
|
||||
if err := w.modifier(&resp); err != nil {
|
||||
w.modifierErr = fmt.Errorf("response modifier error: %w", err)
|
||||
resp.Status = w.modifierErr.Error()
|
||||
w.w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.modified = true
|
||||
w.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
|
||||
w.WriteHeader(w.code)
|
||||
if w.modifierErr != nil {
|
||||
return 0, w.modifierErr
|
||||
}
|
||||
|
||||
n, err := w.w.Write(b)
|
||||
w.size += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection.
|
||||
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.w.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (w *ModifyResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
530
internal/net/gphttp/reverseproxy/reverse_proxy_mod.go
Normal file
530
internal/net/gphttp/reverseproxy/reverse_proxy_mod.go
Normal file
@@ -0,0 +1,530 @@
|
||||
// Copyright 2011 The Go Authors.
|
||||
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
|
||||
// https://cs.opensource.google/go/go/+/master:LICENSE
|
||||
|
||||
package reverseproxy
|
||||
|
||||
// This is a small mod on net/http/httputil/reverseproxy.go
|
||||
// that boosts performance in some cases
|
||||
// and compatible to other modules of this project
|
||||
// Copyright (c) 2024 yusing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// A ProxyRequest contains a request to be rewritten by a [ReverseProxy].
|
||||
type ProxyRequest struct {
|
||||
// In is the request received by the proxy.
|
||||
// The Rewrite function must not modify In.
|
||||
In *http.Request
|
||||
|
||||
// Out is the request which will be sent by the proxy.
|
||||
// The Rewrite function may modify or replace this request.
|
||||
// Hop-by-hop headers are removed from this request
|
||||
// before Rewrite is called.
|
||||
Out *http.Request
|
||||
}
|
||||
|
||||
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
|
||||
// X-Forwarded-Proto headers of the outbound request.
|
||||
//
|
||||
// - The X-Forwarded-For header is set to the client IP address.
|
||||
// - The X-Forwarded-Host header is set to the host name requested
|
||||
// by the client.
|
||||
// - The X-Forwarded-Proto header is set to "http" or "https", depending
|
||||
// on whether the inbound request was made on a TLS-enabled connection.
|
||||
//
|
||||
// If the outbound request contains an existing X-Forwarded-For header,
|
||||
// SetXForwarded appends the client IP address to it. To append to the
|
||||
// inbound request's X-Forwarded-For header (the default behavior of
|
||||
// [ReverseProxy] when using a Director function), copy the header
|
||||
// from the inbound request before calling SetXForwarded:
|
||||
//
|
||||
// rewriteFunc := func(r *httputil.ProxyRequest) {
|
||||
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
// r.SetXForwarded()
|
||||
// }
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
//
|
||||
// 1xx responses are forwarded to the client if the underlying
|
||||
// transport supports ClientTrace.Got1xxResponse.
|
||||
type ReverseProxy struct {
|
||||
zerolog.Logger
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// ModifyResponse is an optional function that modifies the
|
||||
// Response from the backend. It is called if the backend
|
||||
// returns a response at all, with any HTTP status code.
|
||||
// If the backend is unreachable, the optional ErrorHandler is
|
||||
// called before ModifyResponse.
|
||||
//
|
||||
// If ModifyResponse returns an error, ErrorHandler is called
|
||||
// with its error value. If ErrorHandler is nil, its default
|
||||
// implementation is used.
|
||||
ModifyResponse func(*http.Response) error
|
||||
AccessLogger *accesslog.AccessLogger
|
||||
|
||||
HandlerFunc http.HandlerFunc
|
||||
|
||||
TargetName string
|
||||
TargetURL *types.URL
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
if a.RawPath == "" && b.RawPath == "" {
|
||||
return singleJoiningSlash(a.Path, b.Path), ""
|
||||
}
|
||||
// Same as singleJoiningSlash, but uses EscapedPath to determine
|
||||
// whether a slash should be added
|
||||
apath := a.EscapedPath()
|
||||
bpath := b.EscapedPath()
|
||||
|
||||
aslash := strings.HasSuffix(apath, "/")
|
||||
bslash := strings.HasPrefix(bpath, "/")
|
||||
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a.Path + b.Path[1:], apath + bpath[1:]
|
||||
case !aslash && !bslash:
|
||||
return a.Path + "/" + b.Path, apath + "/" + bpath
|
||||
}
|
||||
return a.Path + b.Path, apath + bpath
|
||||
}
|
||||
|
||||
// NewReverseProxy returns a new [ReverseProxy] that routes
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
if transport == nil {
|
||||
panic("nil transport")
|
||||
}
|
||||
rp := &ReverseProxy{
|
||||
Logger: logging.With().Str("name", name).Logger(),
|
||||
Transport: transport,
|
||||
TargetName: name,
|
||||
TargetURL: target,
|
||||
}
|
||||
rp.HandlerFunc = rp.handler
|
||||
return rp
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
|
||||
targetQuery := p.TargetURL.RawQuery
|
||||
req.URL.Scheme = p.TargetURL.Scheme
|
||||
req.URL.Host = p.TargetURL.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(&p.TargetURL.URL, req.URL)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) {
|
||||
reqURL := r.Host + r.URL.Path
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled),
|
||||
errors.Is(err, io.EOF),
|
||||
errors.Is(err, context.DeadlineExceeded):
|
||||
logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
default:
|
||||
var recordErr tls.RecordHeaderError
|
||||
if errors.As(err, &recordErr) {
|
||||
logging.Error().
|
||||
Str("url", reqURL).
|
||||
Msgf(`scheme was likely misconfigured as https,
|
||||
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
|
||||
logging.Err(err).Msg("underlying error")
|
||||
} else {
|
||||
logging.Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
}
|
||||
}
|
||||
|
||||
if writeHeader {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
if p.AccessLogger != nil {
|
||||
p.AccessLogger.LogError(r, err)
|
||||
}
|
||||
}
|
||||
|
||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||
// and reports whether the request should proceed.
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool {
|
||||
if p.ModifyResponse == nil {
|
||||
return true
|
||||
}
|
||||
res.Request = origReq
|
||||
err := p.ModifyResponse(res)
|
||||
res.Request = req
|
||||
if err != nil {
|
||||
res.Body.Close()
|
||||
p.errorHandler(rw, req, err, true)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.HandlerFunc(rw, req)
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
|
||||
ctx := req.Context()
|
||||
/* trunk-ignore(golangci-lint/revive) */
|
||||
if ctx.Done() != nil {
|
||||
// CloseNotifier predates context.Context, and has been
|
||||
// entirely superseded by it. If the request contains
|
||||
// a Context that carries a cancellation signal, don't
|
||||
// bother spinning up a goroutine to watch the CloseNotify
|
||||
// channel (if any).
|
||||
//
|
||||
// If the request Context has a nil Done channel (which
|
||||
// means it is either context.Background, or a custom
|
||||
// Context implementation with no cancellation signal),
|
||||
// then consult the CloseNotifier if available.
|
||||
} else if cn, ok := rw.(http.CloseNotifier); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
notifyChan := cn.CloseNotify()
|
||||
go func() {
|
||||
select {
|
||||
case <-notifyChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
outreq := req.Clone(ctx)
|
||||
if req.ContentLength == 0 {
|
||||
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
// Reading from the request body after returning from a handler is not
|
||||
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||
// this handler. This can lead to a crash if the handler panics (see
|
||||
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||
// any Read in flight after the handle returns, in practice it's safe to
|
||||
// read after closing it.
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
p.rewriteRequestURL(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := httpheaders.UpgradeType(outreq.Header)
|
||||
if !IsPrint(reqUpType) {
|
||||
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Del("Forwarded")
|
||||
httpheaders.RemoveHopByHopHeaders(outreq.Header)
|
||||
|
||||
// Issue 21096: tell backend applications that care about trailer support
|
||||
// that we support trailers. (We do, but we don't go out of our way to
|
||||
// advertise that unless the incoming client request thought it was worth
|
||||
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||
// the latter has passed through removeHopByHopHeaders.
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
// After stripping all the hop-by-hop connection headers above, add back any
|
||||
// necessary for protocol upgrades, such as for websockets.
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
|
||||
if strings.EqualFold(reqUpType, "websocket") {
|
||||
cleanWebsocketHeaders(outreq)
|
||||
}
|
||||
}
|
||||
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor]
|
||||
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
|
||||
|
||||
xff, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
xff = req.RemoteAddr
|
||||
}
|
||||
if len(prior) > 0 {
|
||||
xff = strings.Join(prior, ", ") + ", " + xff
|
||||
}
|
||||
if !omit {
|
||||
outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff)
|
||||
}
|
||||
|
||||
var reqScheme string
|
||||
if req.TLS != nil {
|
||||
reqScheme = "https"
|
||||
} else {
|
||||
reqScheme = "http"
|
||||
}
|
||||
|
||||
outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method)
|
||||
outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme)
|
||||
outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host)
|
||||
outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI)
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
// If the outbound request doesn't have a User-Agent header set,
|
||||
// don't send the default Go HTTP client User-Agent.
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
var (
|
||||
roundTripMutex sync.Mutex
|
||||
roundTripDone bool
|
||||
)
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
roundTripMutex.Lock()
|
||||
defer roundTripMutex.Unlock()
|
||||
if roundTripDone {
|
||||
// If RoundTrip has returned, don't try to further modify
|
||||
// the ResponseWriter's header map.
|
||||
return nil
|
||||
}
|
||||
h := rw.Header()
|
||||
copyHeader(h, http.Header(header))
|
||||
rw.WriteHeader(code)
|
||||
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
clear(h)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
|
||||
roundTripMutex.Lock()
|
||||
roundTripDone = true
|
||||
roundTripMutex.Unlock()
|
||||
if err != nil {
|
||||
p.errorHandler(rw, outreq, err, false)
|
||||
res = &http.Response{
|
||||
Status: http.StatusText(http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Proto: req.Proto,
|
||||
ProtoMajor: req.ProtoMajor,
|
||||
ProtoMinor: req.ProtoMinor,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}
|
||||
}
|
||||
|
||||
if p.AccessLogger != nil {
|
||||
defer func() {
|
||||
p.AccessLogger.Log(req, res)
|
||||
}()
|
||||
}
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
if !p.modifyResponse(rw, res, req, outreq) {
|
||||
return
|
||||
}
|
||||
p.handleUpgradeResponse(rw, outreq, res)
|
||||
return
|
||||
}
|
||||
|
||||
httpheaders.RemoveHopByHopHeaders(res.Header)
|
||||
|
||||
if !p.modifyResponse(rw, res, req, outreq) {
|
||||
return
|
||||
}
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
announcedTrailers := len(res.Trailer)
|
||||
if announcedTrailers > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
|
||||
err = U.CopyClose(U.NewContextWriter(ctx, rw), U.NewContextReader(ctx, res.Body)) // close now, instead of defer, to populate res.Trailer
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
p.errorHandler(rw, req, err, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
http.NewResponseController(rw).Flush()
|
||||
}
|
||||
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range res.Trailer {
|
||||
k = http.TrailerPrefix + k
|
||||
for _, v := range vv {
|
||||
rw.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reference: https://github.com/traefik/traefik/blob/master/pkg/proxy/httputil/proxy.go
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebsocketHeaders(req *http.Request) {
|
||||
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
||||
delete(req.Header, "Sec-Websocket-Key")
|
||||
|
||||
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
|
||||
delete(req.Header, "Sec-Websocket-Extensions")
|
||||
|
||||
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
|
||||
delete(req.Header, "Sec-Websocket-Accept")
|
||||
|
||||
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
|
||||
delete(req.Header, "Sec-Websocket-Protocol")
|
||||
|
||||
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
|
||||
delete(req.Header, "Sec-Websocket-Version")
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := httpheaders.UpgradeType(req.Header)
|
||||
resUpType := httpheaders.UpgradeType(res.Header)
|
||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
|
||||
return
|
||||
}
|
||||
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true)
|
||||
return
|
||||
}
|
||||
|
||||
rc := http.NewResponseController(rw)
|
||||
conn, brw, hijackErr := rc.Hijack()
|
||||
if errors.Is(hijackErr, http.ErrNotSupported) {
|
||||
p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true)
|
||||
return
|
||||
}
|
||||
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
if hijackErr != nil {
|
||||
p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
res.Header = rw.Header()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
|
||||
return
|
||||
}
|
||||
|
||||
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
bdp.Start()
|
||||
}
|
||||
|
||||
func IsPrint(s string) bool {
|
||||
for _, r := range s {
|
||||
if r < ' ' || r > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
31
internal/net/gphttp/serve_mux.go
Normal file
31
internal/net/gphttp/serve_mux.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
||||
type ServeMux struct {
|
||||
*http.ServeMux
|
||||
}
|
||||
|
||||
func NewServeMux() ServeMux {
|
||||
return ServeMux{http.NewServeMux()}
|
||||
}
|
||||
|
||||
func (mux ServeMux) Handle(pattern string, handler http.Handler) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
mux.ServeMux.Handle(pattern, handler)
|
||||
return
|
||||
}
|
||||
|
||||
func (mux ServeMux) HandleFunc(pattern string, handler http.HandlerFunc) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
mux.ServeMux.HandleFunc(pattern, handler)
|
||||
return
|
||||
}
|
||||
18
internal/net/gphttp/server/error.go
Normal file
18
internal/net/gphttp/server/error.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func HandleError(logger *zerolog.Logger, err error, msg string) {
|
||||
switch {
|
||||
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
|
||||
return
|
||||
default:
|
||||
logger.Fatal().Err(err).Msg(msg)
|
||||
}
|
||||
}
|
||||
163
internal/net/gphttp/server/server.go
Normal file
163
internal/net/gphttp/server/server.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Name string
|
||||
CertProvider *autocert.Provider
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
httpStarted bool
|
||||
httpsStarted bool
|
||||
startTime time.Time
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Name string
|
||||
HTTPAddr string
|
||||
HTTPSAddr string
|
||||
CertProvider *autocert.Provider
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
func StartServer(parent task.Parent, opt Options) (s *Server) {
|
||||
s = NewServer(opt)
|
||||
s.Start(parent)
|
||||
return s
|
||||
}
|
||||
|
||||
func NewServer(opt Options) (s *Server) {
|
||||
var httpSer, httpsSer *http.Server
|
||||
|
||||
logger := logging.With().Str("server", opt.Name).Logger()
|
||||
|
||||
certAvailable := false
|
||||
if opt.CertProvider != nil {
|
||||
_, err := opt.CertProvider.GetCert(nil)
|
||||
certAvailable = err == nil
|
||||
}
|
||||
|
||||
out := io.Discard
|
||||
if common.IsDebug {
|
||||
out = logger
|
||||
}
|
||||
|
||||
if opt.HTTPAddr != "" {
|
||||
httpSer = &http.Server{
|
||||
Addr: opt.HTTPAddr,
|
||||
Handler: opt.Handler,
|
||||
ErrorLog: log.New(out, "", 0), // most are tls related
|
||||
}
|
||||
}
|
||||
if certAvailable && opt.HTTPSAddr != "" {
|
||||
httpsSer = &http.Server{
|
||||
Addr: opt.HTTPSAddr,
|
||||
Handler: opt.Handler,
|
||||
ErrorLog: log.New(out, "", 0), // most are tls related
|
||||
TLSConfig: &tls.Config{
|
||||
GetCertificate: opt.CertProvider.GetCert,
|
||||
},
|
||||
}
|
||||
}
|
||||
return &Server{
|
||||
Name: opt.Name,
|
||||
CertProvider: opt.CertProvider,
|
||||
http: httpSer,
|
||||
https: httpsSer,
|
||||
l: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start will start the http and https servers.
|
||||
//
|
||||
// If both are not set, this does nothing.
|
||||
//
|
||||
// Start() is non-blocking.
|
||||
func (s *Server) Start(parent task.Parent) {
|
||||
if s.http == nil && s.https == nil {
|
||||
return
|
||||
}
|
||||
|
||||
task := parent.Subtask("server."+s.Name, false)
|
||||
|
||||
s.startTime = time.Now()
|
||||
if s.http != nil {
|
||||
go func() {
|
||||
err := s.http.ListenAndServe()
|
||||
if err != nil {
|
||||
s.handleErr(err, "failed to serve http server")
|
||||
}
|
||||
}()
|
||||
s.httpStarted = true
|
||||
s.l.Info().Str("addr", s.http.Addr).Msg("server started")
|
||||
}
|
||||
|
||||
if s.https != nil {
|
||||
go func() {
|
||||
l, err := net.Listen("tcp", s.https.Addr)
|
||||
if err != nil {
|
||||
s.handleErr(err, "failed to listen on port")
|
||||
return
|
||||
}
|
||||
defer l.Close()
|
||||
s.handleErr(s.https.Serve(tls.NewListener(l, s.https.TLSConfig)), "failed to serve https server")
|
||||
}()
|
||||
s.httpsStarted = true
|
||||
s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
|
||||
}
|
||||
|
||||
task.OnCancel("stop", s.stop)
|
||||
}
|
||||
|
||||
func (s *Server) stop() {
|
||||
if s.http == nil && s.https == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if s.http != nil && s.httpStarted {
|
||||
err := s.http.Shutdown(ctx)
|
||||
if err != nil {
|
||||
s.handleErr(err, "failed to shutdown http server")
|
||||
} else {
|
||||
s.httpStarted = false
|
||||
s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped")
|
||||
}
|
||||
}
|
||||
|
||||
if s.https != nil && s.httpsStarted {
|
||||
err := s.https.Shutdown(ctx)
|
||||
if err != nil {
|
||||
s.handleErr(err, "failed to shutdown https server")
|
||||
} else {
|
||||
s.httpsStarted = false
|
||||
s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Uptime() time.Duration {
|
||||
return time.Since(s.startTime)
|
||||
}
|
||||
|
||||
func (s *Server) handleErr(err error, msg string) {
|
||||
HandleError(&s.l, err, msg)
|
||||
}
|
||||
11
internal/net/gphttp/status_code.go
Normal file
11
internal/net/gphttp/status_code.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package gphttp
|
||||
|
||||
import "net/http"
|
||||
|
||||
func IsSuccess(status int) bool {
|
||||
return status >= http.StatusOK && status < http.StatusMultipleChoices
|
||||
}
|
||||
|
||||
func IsStatusCodeValid(status int) bool {
|
||||
return http.StatusText(status) != ""
|
||||
}
|
||||
34
internal/net/gphttp/transport.go
Normal file
34
internal/net/gphttp/transport.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package gphttp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DefaultDialer = net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
func NewTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: DefaultDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
// DisableCompression: true, // Prevent double compression
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
WriteBufferSize: 16 * 1024, // 16KB
|
||||
ReadBufferSize: 16 * 1024, // 16KB
|
||||
}
|
||||
}
|
||||
|
||||
func NewTransportWithTLSConfig(tlsConfig *tls.Config) *http.Transport {
|
||||
tr := NewTransport()
|
||||
tr.TLSClientConfig = tlsConfig
|
||||
return tr
|
||||
}
|
||||
Reference in New Issue
Block a user