diff --git a/internal/config/config.go b/internal/config/config.go index 6771fc67..924f3821 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -174,6 +174,7 @@ func (cfg *Config) load() E.Error { // errors are non fatal below errs := E.NewBuilder(errMsg) errs.Add(entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) + errs.Add(entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initAutoCert(&model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) diff --git a/internal/config/types/config.go b/internal/config/types/config.go index f619a102..39bd12c5 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -1,5 +1,7 @@ package types +import "github.com/yusing/go-proxy/internal/net/http/accesslog" + type ( Config struct { AutoCert AutoCertConfig `json:"autocert" yaml:",flow"` @@ -15,7 +17,8 @@ type ( Notification []NotificationConfig `json:"notification" yaml:"notification"` } Entrypoint struct { - Middlewares []map[string]any `json:"middlewares" yaml:"middlewares"` + Middlewares []map[string]any `json:"middlewares" yaml:"middlewares"` + AccessLog *accesslog.Config `json:"access_log" yaml:"access_log"` } NotificationConfig map[string]any ) diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 0104d88e..1ae21ebd 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -39,7 +39,7 @@ const ( // TODO: support stream func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { - hcCfg := entry.HealthCheckConfig() + hcCfg := entry.RawEntry().HealthCheck hcCfg.Timeout = idleWakerCheckTimeout waker := &waker{ diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 5e66c97f..452bd651 100644 --- a/internal/entrypoint/entrypoint.go +++ 
b/internal/entrypoint/entrypoint.go @@ -7,10 +7,13 @@ import ( "strings" "sync" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/route/routes" route "github.com/yusing/go-proxy/internal/route/types" + "github.com/yusing/go-proxy/internal/task" ) var findRouteFunc = findRouteAnyDomain @@ -18,6 +21,9 @@ var findRouteFunc = findRouteAnyDomain var ( epMiddleware *middleware.Middleware epMiddlewareMu sync.Mutex + + epAccessLogger *accesslog.AccessLogger + epAccessLoggerMu sync.Mutex ) func SetFindRouteDomains(domains []string) { @@ -47,6 +53,24 @@ func SetMiddlewares(mws []map[string]any) error { return nil } +// SetAccessLogger installs the entrypoint access logger built from cfg; a nil cfg removes it. +func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) { + epAccessLoggerMu.Lock() + defer epAccessLoggerMu.Unlock() + + if cfg == nil { + epAccessLogger = nil + return + } + + epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) + if err != nil { + return + } + logger.Debug().Msg("entrypoint access logger created") + return +} + func Handler(w http.ResponseWriter, r *http.Request) { mux, err := findRouteFunc(r.Host) if err != nil { @@ -58,6 +82,16 @@ func Handler(w http.ResponseWriter, r *http.Request) { } } if err == nil { + if epAccessLogger != nil { + epAccessLoggerMu.Lock() // epAccessLogger has its own mutex; epMiddlewareMu does not exclude SetAccessLogger + if accessLogger := epAccessLogger; accessLogger != nil { // re-check under the lock and pin this instance for the closure + w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { + accessLogger.Log(r, resp) + return nil + }) + } + epAccessLoggerMu.Unlock() + } if epMiddleware != nil { epMiddlewareMu.Lock() if epMiddleware != nil { diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/http/accesslog/access_logger.go new file mode 100644 index 00000000..b9d35fa8 --- /dev/null +++ b/internal/net/http/accesslog/access_logger.go @@ -0,0 +1,133 @@ +package accesslog + +import ( + "bytes" + "io"
+ "net/http" + "os" + "time" + + "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" +) + +type ( + AccessLogger struct { + parent *task.Task + buf chan []byte + cfg *Config + w io.WriteCloser + Formatter + } + + Formatter interface { + // Format writes a log line to line without a trailing newline + Format(line *bytes.Buffer, req *http.Request, res *http.Response) + } +) + +var logger = logging.With().Str("module", "accesslog").Logger() + +var TestTimeNow = time.Now().Format(logTimeFormat) + +const logTimeFormat = "02/Jan/2006:15:04:05 -0700" + +func NewFileAccessLogger(parent *task.Task, cfg *Config) (*AccessLogger, error) { + f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return nil, err + } + return NewAccessLogger(parent, f, cfg), nil +} + +func NewAccessLogger(parent *task.Task, w io.WriteCloser, cfg *Config) *AccessLogger { + l := &AccessLogger{ + parent: parent, + cfg: cfg, + w: w, + } + fmt := CommonFormatter{cfg: &l.cfg.Fields} + switch l.cfg.Format { + case FormatCommon: + l.Formatter = fmt + case FormatCombined: + l.Formatter = CombinedFormatter{CommonFormatter: fmt} + case FormatJSON: + l.Formatter = JSONFormatter{CommonFormatter: fmt} + } + if cfg.BufferSize == 0 { + cfg.BufferSize = DefaultBufferSize + } + l.buf = make(chan []byte, cfg.BufferSize) + go l.start() + return l +} + +func timeNow() string { + if !common.IsTest { + return time.Now().Format(logTimeFormat) + } + return TestTimeNow +} + +func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool { + if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) || + !l.cfg.Filters.Method.CheckKeep(req, res) || + !l.cfg.Filters.Headers.CheckKeep(req, res) || + !l.cfg.Filters.CIDR.CheckKeep(req, res) { + return false + } + return true +} + +func (l *AccessLogger) Log(req *http.Request, res *http.Response) { + 
if !l.checkKeep(req, res) { + return + } + + var line bytes.Buffer + l.Format(&line, req, res) + line.WriteRune('\n') + + select { + case <-l.parent.Context().Done(): + return + case l.buf <- line.Bytes(): + // queued for the flusher; sending inside the select lets a full buffer unblock on shutdown instead of hanging the request + } +} + +func (l *AccessLogger) LogError(req *http.Request, err error) { + l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) +} + +func (l *AccessLogger) close() { + // l.buf is deliberately left open: a concurrent Log may still select the send case, and a send on a closed channel panics + l.w.Close() +} + +func (l *AccessLogger) handleErr(err error) { + E.LogError("failed to write access log", err, &logger) +} + +// start is the flusher goroutine: it writes queued lines to l.w until the subtask context is cancelled, then closes the writer. +func (l *AccessLogger) start() { + task := l.parent.Subtask("access log flusher") + defer task.Finish("done") + defer l.close() + + for { + select { + case <-task.Context().Done(): + return + case line := <-l.buf: + // receive inside the select so cancellation is observed between lines (ranging over l.buf never returned) + _, err := l.w.Write(line) + if err != nil { + l.handleErr(err) + } + } + } +} diff --git a/internal/net/http/accesslog/access_logger_test.go b/internal/net/http/accesslog/access_logger_test.go new file mode 100644 index 00000000..f44b776b --- /dev/null +++ b/internal/net/http/accesslog/access_logger_test.go @@ -0,0 +1,132 @@ +package accesslog_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "testing" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/net/http/accesslog" + taskPkg "github.com/yusing/go-proxy/internal/task" + .
"github.com/yusing/go-proxy/internal/utils/testing" +) + +type testWritter struct { + line string +} + +func (w *testWritter) Write(p []byte) (n int, err error) { + w.line = string(p) + return len(p), nil +} + +func (w *testWritter) Close() error { + return nil +} + +var tw testWritter + +const ( + remote = "192.168.1.1" + u = "http://example.com/?bar=baz&foo=bar" + uRedacted = "http://example.com/?bar=" + RedactedValue + "&foo=" + RedactedValue + referer = "https://www.google.com/" + proto = "HTTP/1.1" + ua = "Go-http-client/1.1" + status = http.StatusOK + contentLength = 100 + method = http.MethodGet +) + +var ( + testURL = E.Must(url.Parse(u)) + req = &http.Request{ + RemoteAddr: remote, + Method: method, + Proto: proto, + Host: testURL.Host, + URL: testURL, + Header: http.Header{ + "User-Agent": []string{ua}, + "Referer": []string{referer}, + "Cookie": []string{ + "foo=bar", + "bar=baz", + }, + }, + } + resp = &http.Response{ + StatusCode: status, + ContentLength: contentLength, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + } + task = taskPkg.GlobalTask("test logger") +) + +func TestAccessLoggerCommon(t *testing.T) { + config := DefaultConfig + config.Format = FormatCommon + logger := NewAccessLogger(task, &tw, &config) + logger.Log(req, resp) + ExpectEqual(t, tw.line, + fmt.Sprintf("%s - - [%s] \"%s %s %s\" %d %d\n", + remote, TestTimeNow, method, u, proto, status, contentLength, + ), + ) +} + +func TestAccessLoggerCombined(t *testing.T) { + config := DefaultConfig + config.Format = FormatCombined + logger := NewAccessLogger(task, &tw, &config) + logger.Log(req, resp) + ExpectEqual(t, tw.line, + fmt.Sprintf("%s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"\n", + remote, TestTimeNow, method, u, proto, status, contentLength, referer, ua, + ), + ) +} + +func TestAccessLoggerRedactQuery(t *testing.T) { + config := DefaultConfig + config.Format = FormatCommon + config.Fields.Query.DefaultMode = FieldModeRedact + logger := NewAccessLogger(task, &tw, 
&config) + logger.Log(req, resp) + ExpectEqual(t, tw.line, + fmt.Sprintf("%s - - [%s] \"%s %s %s\" %d %d\n", + remote, TestTimeNow, method, uRedacted, proto, status, contentLength, + ), + ) +} + +func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { + t.Helper() + config.Format = FormatJSON + logger := NewAccessLogger(task, &tw, config) + logger.Log(req, resp) + var entry JSONLogEntry + err := json.Unmarshal([]byte(tw.line), &entry) + ExpectNoError(t, err) + return entry +} + +func TestAccessLoggerJSON(t *testing.T) { + config := DefaultConfig + entry := getJSONEntry(t, &config) + ExpectEqual(t, entry.IP, remote) + ExpectEqual(t, entry.Method, method) + ExpectEqual(t, entry.Scheme, "http") + ExpectEqual(t, entry.Host, testURL.Host) + ExpectEqual(t, entry.URI, testURL.RequestURI()) + ExpectEqual(t, entry.Protocol, proto) + ExpectEqual(t, entry.Status, status) + ExpectEqual(t, entry.ContentType, "text/plain") + ExpectEqual(t, entry.Size, contentLength) + ExpectEqual(t, entry.Referer, referer) + ExpectEqual(t, entry.UserAgent, ua) + ExpectEqual(t, len(entry.Headers), 0) + ExpectEqual(t, len(entry.Cookies), 0) +} diff --git a/internal/net/http/accesslog/config.go b/internal/net/http/accesslog/config.go new file mode 100644 index 00000000..b6a57317 --- /dev/null +++ b/internal/net/http/accesslog/config.go @@ -0,0 +1,47 @@ +package accesslog + +type ( + Format string + Filters struct { + StatusCodes LogFilter[*StatusCodeRange] + Method LogFilter[HTTPMethod] + Headers LogFilter[*HTTPHeader] // header exists or header == value + CIDR LogFilter[*CIDR] + } + Fields struct { + Headers FieldConfig + Query FieldConfig + Cookies FieldConfig + } + Config struct { + BufferSize uint + Format Format `validate:"oneof=common combined json"` + Path string `validate:"required"` + Filters Filters + Fields Fields + } +) + +var ( + FormatCommon Format = "common" + FormatCombined Format = "combined" + FormatJSON Format = "json" +) + +const DefaultBufferSize = 100 + +var 
DefaultConfig = Config{ + BufferSize: DefaultBufferSize, + Format: FormatCombined, + Fields: Fields{ + Headers: FieldConfig{ + DefaultMode: FieldModeDrop, + }, + Query: FieldConfig{ + DefaultMode: FieldModeKeep, + }, + Cookies: FieldConfig{ + DefaultMode: FieldModeDrop, + }, + }, +} diff --git a/internal/net/http/accesslog/config_test.go b/internal/net/http/accesslog/config_test.go new file mode 100644 index 00000000..4706410e --- /dev/null +++ b/internal/net/http/accesslog/config_test.go @@ -0,0 +1,53 @@ +package accesslog_test + +import ( + "testing" + + "github.com/yusing/go-proxy/internal/docker" + . "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/utils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestNewConfig(t *testing.T) { + labels := map[string]string{ + "proxy.buffer_size": "10", + "proxy.format": "combined", + "proxy.file_path": "/tmp/access.log", + "proxy.filters.status_codes.values": "200-299", + "proxy.filters.method.values": "GET, POST", + "proxy.filters.headers.values": "foo=bar, baz", + "proxy.filters.headers.negative": "true", + "proxy.filters.cidr.values": "192.168.10.0/24", + "proxy.fields.headers.default_mode": "keep", + "proxy.fields.headers.config.foo": "redact", + "proxy.fields.query.default_mode": "drop", + "proxy.fields.query.config.foo": "keep", + "proxy.fields.cookies.default_mode": "redact", + "proxy.fields.cookies.config.foo": "keep", + } + parsed, err := docker.ParseLabels(labels) + ExpectNoError(t, err) + + var config Config + err = utils.Deserialize(parsed, &config) + ExpectNoError(t, err) + + ExpectEqual(t, config.BufferSize, 10) + ExpectEqual(t, config.Format, FormatCombined) + ExpectEqual(t, config.Path, "/tmp/access.log") + ExpectDeepEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}}) + ExpectEqual(t, len(config.Filters.Method.Values), 2) + ExpectDeepEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"}) + 
ExpectEqual(t, len(config.Filters.Headers.Values), 2) + ExpectDeepEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}}) + ExpectTrue(t, config.Filters.Headers.Negative) + ExpectEqual(t, len(config.Filters.CIDR.Values), 1) + ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24") + ExpectEqual(t, config.Fields.Headers.DefaultMode, FieldModeKeep) + ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact) + ExpectEqual(t, config.Fields.Query.DefaultMode, FieldModeDrop) + ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep) + ExpectEqual(t, config.Fields.Cookies.DefaultMode, FieldModeRedact) + ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep) +} diff --git a/internal/net/http/accesslog/fields.go b/internal/net/http/accesslog/fields.go new file mode 100644 index 00000000..4b3d4989 --- /dev/null +++ b/internal/net/http/accesslog/fields.go @@ -0,0 +1,103 @@ +package accesslog + +import ( + "net/http" + "net/url" +) + +type ( + FieldConfig struct { + DefaultMode FieldMode `validate:"oneof=keep drop redact"` + Config map[string]FieldMode `validate:"dive,oneof=keep drop redact"` + } + FieldMode string +) + +const ( + FieldModeKeep FieldMode = "keep" + FieldModeDrop FieldMode = "drop" + FieldModeRedact FieldMode = "redact" + + RedactedValue = "REDACTED" +) + +func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V { + if len(cfg.Config) == 0 { + switch cfg.DefaultMode { + case FieldModeKeep: + return m + case FieldModeDrop: + return nil + case FieldModeRedact: + redacted := make(map[string]V) + for k := range m { + redacted[k] = redactedV + } + return redacted + } + } + + if len(m) == 0 { + return m + } + + newMap := make(map[string]V) + for k := range m { + var mode FieldMode + var ok bool + if mode, ok = cfg.Config[k]; !ok { + mode = cfg.DefaultMode + } + switch mode { + case FieldModeKeep: + newMap[k] = m[k] + case FieldModeRedact: + 
newMap[k] = redactedV + } + } + return newMap +} + +func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn { + if len(s) == 0 || + len(cfg.Config) == 0 && cfg.DefaultMode == FieldModeDrop { + return nil + } + newMap := make(map[string]VReturn, len(s)) + for _, v := range s { + var mode FieldMode + var ok bool + k := getKey(v) + if mode, ok = cfg.Config[k]; !ok { + mode = cfg.DefaultMode + } + switch mode { + case FieldModeKeep: + newMap[k] = convert(v) + case FieldModeRedact: + newMap[k] = redact(v) + } + } + return newMap +} + +func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header { + return processMap(cfg, headers, []string{RedactedValue}) +} + +func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values { + return processMap(cfg, q, []string{RedactedValue}) +} + +func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string { + return processSlice(cfg, cookies, + func(c *http.Cookie) string { + return c.Name + }, + func(c *http.Cookie) string { + return c.Value + }, + func(c *http.Cookie) string { + return RedactedValue + }) +} diff --git a/internal/net/http/accesslog/fields_test.go b/internal/net/http/accesslog/fields_test.go new file mode 100644 index 00000000..f8cd1c40 --- /dev/null +++ b/internal/net/http/accesslog/fields_test.go @@ -0,0 +1,72 @@ +package accesslog_test + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/net/http/accesslog" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +// Cookie header should be removed, +// stored in JSONLogEntry.Cookies instead. 
+func TestAccessLoggerJSONKeepHeaders(t *testing.T) { + config := DefaultConfig + config.Fields.Headers.DefaultMode = FieldModeKeep + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0) + for k, v := range req.Header { + if k != "Cookie" { + ExpectDeepEqual(t, entry.Headers[k], v) + } + } +} + +func TestAccessLoggerJSONRedactHeaders(t *testing.T) { + config := DefaultConfig + config.Fields.Headers.DefaultMode = FieldModeRedact + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0) + for k := range req.Header { + if k != "Cookie" { + ExpectDeepEqual(t, entry.Headers[k], []string{RedactedValue}) + } + } +} + +func TestAccessLoggerJSONKeepCookies(t *testing.T) { + config := DefaultConfig + config.Fields.Headers.DefaultMode = FieldModeKeep + config.Fields.Cookies.DefaultMode = FieldModeKeep + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0) + for _, cookie := range req.Cookies() { + ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value) + } +} + +func TestAccessLoggerJSONRedactCookies(t *testing.T) { + config := DefaultConfig + config.Fields.Headers.DefaultMode = FieldModeKeep + config.Fields.Cookies.DefaultMode = FieldModeRedact + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, len(entry.Headers["Cookie"]), 0) + for _, cookie := range req.Cookies() { + ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue) + } +} + +func TestAccessLoggerJSONDropQuery(t *testing.T) { + config := DefaultConfig + config.Fields.Query.DefaultMode = FieldModeDrop + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, entry.Query["foo"], nil) + ExpectDeepEqual(t, entry.Query["bar"], nil) +} + +func TestAccessLoggerJSONRedactQuery(t *testing.T) { + config := DefaultConfig + config.Fields.Query.DefaultMode = FieldModeRedact + entry := getJSONEntry(t, &config) + ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue}) + ExpectDeepEqual(t, entry.Query["bar"], 
[]string{RedactedValue}) +} diff --git a/internal/net/http/accesslog/filter.go b/internal/net/http/accesslog/filter.go new file mode 100644 index 00000000..c92457a3 --- /dev/null +++ b/internal/net/http/accesslog/filter.go @@ -0,0 +1,106 @@ +package accesslog + +import ( + "net" + "net/http" + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +type ( + LogFilter[T Filterable] struct { + Negative bool + Values []T + } + Filterable interface { + comparable + Fulfill(req *http.Request, res *http.Response) bool + } + HTTPMethod string + HTTPHeader struct { + Key, Value string + } + CIDR struct { + *net.IPNet + } +) + +var ErrInvalidHTTPHeaderFilter = E.New("invalid http header filter") + +func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool { + if len(f.Values) == 0 { + return !f.Negative + } + for _, check := range f.Values { + if check.Fulfill(req, res) { + return !f.Negative + } + } + return f.Negative +} + +func (r *StatusCodeRange) Fulfill(req *http.Request, res *http.Response) bool { + return r.Includes(res.StatusCode) +} + +func (method HTTPMethod) Fulfill(req *http.Request, res *http.Response) bool { + return req.Method == string(method) +} + +func (k *HTTPHeader) Parse(v string) error { + split := strings.Split(v, "=") + switch len(split) { + case 1: + split = append(split, "") + case 2: + default: + return ErrInvalidHTTPHeaderFilter.Subject(v) + } + k.Key = split[0] + k.Value = split[1] + return nil +} + +// Fulfill reports whether the request carries header k.Key and, when k.Value is set, a value equal to it ignoring case. +func (k *HTTPHeader) Fulfill(req *http.Request, res *http.Response) bool { + wanted := k.Value + // exact key first (headers assigned directly on the map), then the canonical form net/http stores for incoming requests + got, ok := req.Header[k.Key] + if !ok { + got, ok = req.Header[http.CanonicalHeaderKey(k.Key)] + } + if wanted == "" { + return ok + } + if !ok { + return false + } + for _, v := range got { + if strings.EqualFold(v, wanted) { + return true + } + } + return false +} + +func (cidr *CIDR) Parse(v string) error { + _, ipnet, err := net.ParseCIDR(v) + if err != nil { + return err + } + cidr.IPNet = ipnet + return nil +} + +func (cidr *CIDR) Fulfill(req
*http.Request, res *http.Response) bool { + ip, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + ip = req.RemoteAddr + } + netIP := net.ParseIP(ip) + if netIP == nil { + return false + } + return cidr.Contains(netIP) +} diff --git a/internal/net/http/accesslog/filter_test.go b/internal/net/http/accesslog/filter_test.go new file mode 100644 index 00000000..7160dcee --- /dev/null +++ b/internal/net/http/accesslog/filter_test.go @@ -0,0 +1,188 @@ +package accesslog_test + +import ( + "net/http" + "testing" + + . "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/utils/strutils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestStatusCodeFilter(t *testing.T) { + values := []*StatusCodeRange{ + strutils.MustParse[*StatusCodeRange]("200-308"), + } + t.Run("positive", func(t *testing.T) { + filter := &LogFilter[*StatusCodeRange]{} + ExpectTrue(t, filter.CheckKeep(nil, nil)) + + // keep any 2xx 3xx (inclusive) + filter.Values = values + ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusForbidden, + })) + ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusOK, + })) + ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusMultipleChoices, + })) + ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusPermanentRedirect, + })) + }) + + t.Run("negative", func(t *testing.T) { + filter := &LogFilter[*StatusCodeRange]{ + Negative: true, + } + ExpectFalse(t, filter.CheckKeep(nil, nil)) + + // drop any 2xx 3xx (inclusive) + filter.Values = values + ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusForbidden, + })) + ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusOK, + })) + ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: http.StatusMultipleChoices, + })) + ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + StatusCode: 
http.StatusPermanentRedirect, + })) + }) +} + +func TestMethodFilter(t *testing.T) { + t.Run("positive", func(t *testing.T) { + filter := &LogFilter[HTTPMethod]{} + ExpectTrue(t, filter.CheckKeep(&http.Request{ + Method: http.MethodGet, + }, nil)) + ExpectTrue(t, filter.CheckKeep(&http.Request{ + Method: http.MethodPost, + }, nil)) + + // keep get only + filter.Values = []HTTPMethod{http.MethodGet} + ExpectTrue(t, filter.CheckKeep(&http.Request{ + Method: http.MethodGet, + }, nil)) + ExpectFalse(t, filter.CheckKeep(&http.Request{ + Method: http.MethodPost, + }, nil)) + }) + + t.Run("negative", func(t *testing.T) { + filter := &LogFilter[HTTPMethod]{ + Negative: true, + } + ExpectFalse(t, filter.CheckKeep(&http.Request{ + Method: http.MethodGet, + }, nil)) + ExpectFalse(t, filter.CheckKeep(&http.Request{ + Method: http.MethodPost, + }, nil)) + + // drop post only + filter.Values = []HTTPMethod{http.MethodPost} + ExpectFalse(t, filter.CheckKeep(&http.Request{ + Method: http.MethodPost, + }, nil)) + ExpectTrue(t, filter.CheckKeep(&http.Request{ + Method: http.MethodGet, + }, nil)) + }) +} + +func TestHeaderFilter(t *testing.T) { + fooBar := &http.Request{ + Header: http.Header{ + "Foo": []string{"bar"}, + }, + } + fooBaz := &http.Request{ + Header: http.Header{ + "Foo": []string{"baz"}, + }, + } + headerFoo := []*HTTPHeader{ + strutils.MustParse[*HTTPHeader]("Foo"), + } + ExpectEqual(t, headerFoo[0].Key, "Foo") + ExpectEqual(t, headerFoo[0].Value, "") + headerFooBar := []*HTTPHeader{ + strutils.MustParse[*HTTPHeader]("Foo=bar"), + } + ExpectEqual(t, headerFooBar[0].Key, "Foo") + ExpectEqual(t, headerFooBar[0].Value, "bar") + + t.Run("positive", func(t *testing.T) { + filter := &LogFilter[*HTTPHeader]{} + ExpectTrue(t, filter.CheckKeep(fooBar, nil)) + ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + + // keep any foo + filter.Values = headerFoo + ExpectTrue(t, filter.CheckKeep(fooBar, nil)) + ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + + // keep foo == bar + 
filter.Values = headerFooBar + ExpectTrue(t, filter.CheckKeep(fooBar, nil)) + ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + }) + t.Run("negative", func(t *testing.T) { + filter := &LogFilter[*HTTPHeader]{ + Negative: true, + } + ExpectFalse(t, filter.CheckKeep(fooBar, nil)) + ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + + // drop any foo + filter.Values = headerFoo + ExpectFalse(t, filter.CheckKeep(fooBar, nil)) + ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + + // drop foo == bar + filter.Values = headerFooBar + ExpectFalse(t, filter.CheckKeep(fooBar, nil)) + ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + }) +} + +func TestCIDRFilter(t *testing.T) { + cidr := []*CIDR{ + strutils.MustParse[*CIDR]("192.168.10.0/24"), + } + ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") + inCIDR := &http.Request{ + RemoteAddr: "192.168.10.1", + } + notInCIDR := &http.Request{ + RemoteAddr: "192.168.11.1", + } + + t.Run("positive", func(t *testing.T) { + filter := &LogFilter[*CIDR]{} + ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) + ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + + filter.Values = cidr + ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) + ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + }) + + t.Run("negative", func(t *testing.T) { + filter := &LogFilter[*CIDR]{Negative: true} + ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) + ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + + filter.Values = cidr + ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) + ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + }) +} diff --git a/internal/net/http/accesslog/formatter.go b/internal/net/http/accesslog/formatter.go new file mode 100644 index 00000000..29dabd58 --- /dev/null +++ b/internal/net/http/accesslog/formatter.go @@ -0,0 +1,129 @@ +package accesslog + +import ( + "bytes" + "encoding/json" + "net" + "net/http" + "net/url" + "strconv" +) + +type ( + CommonFormatter struct { + cfg *Fields + } + CombinedFormatter struct { + CommonFormatter + } + JSONFormatter struct { + 
CommonFormatter + } + JSONLogEntry struct { + IP string `json:"ip"` + Method string `json:"method"` + Scheme string `json:"scheme"` + Host string `json:"host"` + URI string `json:"uri"` + Protocol string `json:"protocol"` + Status int `json:"status"` + Error string `json:"error,omitempty"` + ContentType string `json:"type"` + Size int64 `json:"size"` + Referer string `json:"referer"` + UserAgent string `json:"useragent"` + Query map[string][]string `json:"query,omitempty"` + Headers map[string][]string `json:"headers,omitempty"` + Cookies map[string]string `json:"cookies,omitempty"` + } +) + +func scheme(req *http.Request) string { + if req.TLS != nil { + return "https" + } + return "http" +} + +func requestURI(u *url.URL, query url.Values) string { + uri := u.EscapedPath() + if len(query) > 0 { + uri += "?" + query.Encode() + } + return uri +} + +func clientIP(req *http.Request) string { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return clientIP + } + return req.RemoteAddr +} + +func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { + query := f.cfg.Query.ProcessQuery(req.URL.Query()) + + line.WriteString(req.Host) + line.WriteRune(' ') + + line.WriteString(clientIP(req)) + line.WriteString(" - - [") + + line.WriteString(timeNow()) + line.WriteString("] \"") + + line.WriteString(req.Method) + line.WriteRune(' ') + line.WriteString(requestURI(req.URL, query)) + line.WriteRune(' ') + line.WriteString(req.Proto) + line.WriteString("\" ") + + line.WriteString(strconv.Itoa(res.StatusCode)) + line.WriteRune(' ') + line.WriteString(strconv.FormatInt(res.ContentLength, 10)) +} + +func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { + f.CommonFormatter.Format(line, req, res) + line.WriteString(" \"") + line.WriteString(req.Referer()) + line.WriteString("\" \"") + line.WriteString(req.UserAgent()) + line.WriteRune('"') +} + +func (f JSONFormatter) Format(line 
*bytes.Buffer, req *http.Request, res *http.Response) { + query := f.cfg.Query.ProcessQuery(req.URL.Query()) + cookies := f.cfg.Cookies.ProcessCookies(req.Cookies()) // read cookies before the Cookie header is stripped from the logged headers + headers := f.cfg.Headers.ProcessHeaders(req.Header.Clone()) // clone: in keep mode the input map is returned as-is, and Del below must not mutate the live request + headers.Del("Cookie") + + entry := JSONLogEntry{ + IP: clientIP(req), + Method: req.Method, + Scheme: scheme(req), + Host: req.Host, + URI: requestURI(req.URL, query), + Protocol: req.Proto, + Status: res.StatusCode, + ContentType: res.Header.Get("Content-Type"), + Size: res.ContentLength, + Referer: req.Referer(), + UserAgent: req.UserAgent(), + Query: query, + Headers: headers, + Cookies: cookies, + } + + if res.StatusCode >= 400 { + entry.Error = res.Status + } + + data, err := json.Marshal(entry) // Marshal, unlike Encoder.Encode, adds no trailing newline; the caller appends the only one + if err != nil { + logger.Err(err).Msg("failed to marshal json log") + } + line.Write(data) +} diff --git a/internal/net/http/accesslog/status_code_range.go b/internal/net/http/accesslog/status_code_range.go new file mode 100644 index 00000000..01868446 --- /dev/null +++ b/internal/net/http/accesslog/status_code_range.go @@ -0,0 +1,51 @@ +package accesslog + +import ( + "strconv" + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +type StatusCodeRange struct { + Start int + End int +} + +var ErrInvalidStatusCodeRange = E.New("invalid status code range") + +func (r *StatusCodeRange) Includes(code int) bool { + return r.Start <= code && code <= r.End +} + +func (r *StatusCodeRange) Parse(v string) error { + split := strings.Split(v, "-") + switch len(split) { + case 1: + start, err := strconv.Atoi(split[0]) + if err != nil { + return E.From(err) + } + r.Start = start + r.End = start + return nil + case 2: + start, errStart := strconv.Atoi(split[0]) + end, errEnd := strconv.Atoi(split[1]) + if err := E.Join(errStart, errEnd); err != nil { + return err + } + r.Start = start + r.End = end + return nil + default: + return ErrInvalidStatusCodeRange.Subject(v) + } +} + +func (r *StatusCodeRange) String() string { + if
r.Start == r.End { + return strconv.Itoa(r.Start) + } + return strconv.Itoa(r.Start) + "-" + strconv.Itoa(r.End) +} diff --git a/internal/net/http/modify_response_writer.go b/internal/net/http/modify_response_writer.go index 4da0d202..a8c7b89c 100644 --- a/internal/net/http/modify_response_writer.go +++ b/internal/net/http/modify_response_writer.go @@ -18,6 +18,7 @@ type ( headerSent bool code int + size int modifier ModifyResponseFunc modified bool @@ -38,6 +39,14 @@ func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter { return w.w } +func (w *ModifyResponseWriter) StatusCode() int { + return w.code +} + +func (w *ModifyResponseWriter) Size() int { + return w.size +} + func (w *ModifyResponseWriter) WriteHeader(code int) { if w.headerSent { return @@ -58,12 +67,15 @@ func (w *ModifyResponseWriter) WriteHeader(code int) { } resp := http.Response{ - Header: w.w.Header(), - Request: w.r, + StatusCode: code, + Header: w.w.Header(), + Request: w.r, + ContentLength: int64(w.size), } if err := w.modifier(&resp); err != nil { w.modifierErr = fmt.Errorf("response modifier error: %w", err) + resp.Status = w.modifierErr.Error() w.w.WriteHeader(http.StatusInternalServerError) return } @@ -81,7 +93,10 @@ func (w *ModifyResponseWriter) Write(b []byte) (int, error) { if w.modifierErr != nil { return 0, w.modifierErr } - return w.w.Write(b) + + n, err := w.w.Write(b) + w.size += n + return n, err } // Hijack hijacks the connection. 
diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 5cc62457..47b008d1 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -27,6 +27,7 @@ import ( "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/metrics" + "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" "golang.org/x/net/http/httpguts" @@ -88,6 +89,7 @@ type ReverseProxy struct { // with its error value. If ErrorHandler is nil, its default // implementation is used. ModifyResponse func(*http.Response) error + AccessLogger *accesslog.AccessLogger HandlerFunc http.HandlerFunc @@ -245,7 +247,10 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error") } if writeHeader { - rw.WriteHeader(http.StatusBadGateway) + rw.WriteHeader(http.StatusInternalServerError) + } + if p.AccessLogger != nil { + p.AccessLogger.LogError(r, err) } } @@ -271,37 +276,19 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { + visitorIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + visitorIP = req.RemoteAddr + } + if common.PrometheusEnabled { t := time.Now() - var visitor string - if realIPs := req.Header.Values(HeaderXRealIP); len(realIPs) > 0 { - if len(realIPs) == 1 { - visitor = realIPs[0] - } else { - p.Warn().Strs("real_ips", realIPs). - Str("remote_addr", req.RemoteAddr). - Str("request_url", req.URL.String()). 
- Msg("client sent multiple 'X-Real-IP' values, ignoring.") - } - } - if visitor == "" { - if fwdIPs := req.Header.Values(HeaderXForwardedFor); len(fwdIPs) > 0 { - // right-most IP is the visitor - visitor = fwdIPs[len(fwdIPs)-1] - } - } - if visitor == "" { - var err error - visitor, _, err = net.SplitHostPort(req.RemoteAddr) - if err != nil { - visitor = req.RemoteAddr - } - } + // req.RemoteAddr had been modified by middleware (if any) lbls := &metrics.HTTPRouteMetricLabels{ Service: p.TargetName, Method: req.Method, Host: req.Host, - Visitor: visitor, + Visitor: visitorIP, Path: req.URL.Path, } rw = &httpMetricLogger{ @@ -389,18 +376,17 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { } } - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - prior, ok := outreq.Header[HeaderXForwardedFor] - omit := ok && prior == nil // Issue 38079: nil now means don't populate the header - if len(prior) > 0 { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - if !omit { - outreq.Header.Set(HeaderXForwardedFor, clientIP) - } + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. 
+ prior, ok := outreq.Header[HeaderXForwardedFor] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + xff := visitorIP + if len(prior) > 0 { + xff = strings.Join(prior, ", ") + ", " + xff + } + if !omit { + outreq.Header.Set(HeaderXForwardedFor, xff) } var reqScheme string @@ -465,6 +451,12 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { } } + if p.AccessLogger != nil { + defer func() { + p.AccessLogger.Log(req, res) + }() + } + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { if !p.modifyResponse(rw, res, req, outreq) { diff --git a/internal/route/entry/entry.go b/internal/route/entry/entry.go index 27c0c70f..67f38725 100644 --- a/internal/route/entry/entry.go +++ b/internal/route/entry/entry.go @@ -43,7 +43,7 @@ func ShouldNotServe(entry Entry) bool { } func UseLoadBalance(entry Entry) bool { - lb := entry.LoadBalanceConfig() + lb := entry.RawEntry().LoadBalance return lb != nil && lb.Link != "" } @@ -53,6 +53,10 @@ func UseIdleWatcher(entry Entry) bool { } func UseHealthCheck(entry Entry) bool { - hc := entry.HealthCheckConfig() + hc := entry.RawEntry().HealthCheck return hc != nil && !hc.Disable } + +func UseAccessLog(entry Entry) bool { + return entry.RawEntry().AccessLog != nil +} diff --git a/internal/route/entry/reverse_proxy.go b/internal/route/entry/reverse_proxy.go index 115db855..f76fb715 100644 --- a/internal/route/entry/reverse_proxy.go +++ b/internal/route/entry/reverse_proxy.go @@ -7,10 +7,8 @@ import ( "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" - loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" net "github.com/yusing/go-proxy/internal/net/types" route "github.com/yusing/go-proxy/internal/route/types" - "github.com/yusing/go-proxy/internal/watcher/health" ) type 
ReverseProxyEntry struct { // real model after validation @@ -33,14 +31,6 @@ func (rp *ReverseProxyEntry) RawEntry() *route.RawEntry { return rp.Raw } -func (rp *ReverseProxyEntry) LoadBalanceConfig() *loadbalance.Config { - return rp.Raw.LoadBalance -} - -func (rp *ReverseProxyEntry) HealthCheckConfig() *health.HealthCheckConfig { - return rp.Raw.HealthCheck -} - func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config { return rp.Idlewatcher } diff --git a/internal/route/entry/stream.go b/internal/route/entry/stream.go index 4c5962cc..313321d9 100644 --- a/internal/route/entry/stream.go +++ b/internal/route/entry/stream.go @@ -6,10 +6,8 @@ import ( "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" - loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" net "github.com/yusing/go-proxy/internal/net/types" route "github.com/yusing/go-proxy/internal/route/types" - "github.com/yusing/go-proxy/internal/watcher/health" ) type StreamEntry struct { @@ -36,15 +34,6 @@ func (s *StreamEntry) RawEntry() *route.RawEntry { return s.Raw } -func (s *StreamEntry) LoadBalanceConfig() *loadbalance.Config { - // TODO: support stream load balance - return nil -} - -func (s *StreamEntry) HealthCheckConfig() *health.HealthCheckConfig { - return s.Raw.HealthCheck -} - func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config { return s.Idlewatcher } diff --git a/internal/route/http.go b/internal/route/http.go index 13c8335d..b294856f 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -9,6 +9,7 @@ import ( "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" loadbalance 
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/http/middleware" @@ -105,6 +106,15 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { } } + if entry.UseAccessLog(r) { + var err error + r.rp.AccessLogger, err = accesslog.NewFileAccessLogger(r.task, r.Raw.AccessLog) + if err != nil { + r.task.Finish(err) + return E.From(err) + } + } + if r.handler == nil { pathPatterns := r.Raw.PathPatterns switch { diff --git a/internal/route/stream.go b/internal/route/stream.go index 1a946309..eb230b46 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -16,6 +16,7 @@ import ( "github.com/yusing/go-proxy/internal/watcher/health/monitor" ) +// TODO: support stream load balance type StreamRoute struct { *entry.StreamEntry diff --git a/internal/route/types/entry.go b/internal/route/types/entry.go index ae907ded..27cb6238 100644 --- a/internal/route/types/entry.go +++ b/internal/route/types/entry.go @@ -2,16 +2,12 @@ package types import ( idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" - loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" net "github.com/yusing/go-proxy/internal/net/types" - "github.com/yusing/go-proxy/internal/watcher/health" ) type Entry interface { TargetName() string TargetURL() net.URL RawEntry() *RawEntry - LoadBalanceConfig() *loadbalance.Config - HealthCheckConfig() *health.HealthCheckConfig IdlewatcherConfig() *idlewatcher.Config } diff --git a/internal/route/types/raw_entry.go b/internal/route/types/raw_entry.go index e46226be..5c0a0b2f 100644 --- a/internal/route/types/raw_entry.go +++ b/internal/route/types/raw_entry.go @@ -10,6 +10,7 @@ import ( "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/net/http/accesslog" loadbalance 
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -33,7 +34,7 @@ type ( LoadBalance *loadbalance.Config `json:"load_balance,omitempty" yaml:"load_balance"` Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty" yaml:"middlewares"` Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` - // AccessLog *accesslog.Config `json:"access_log,omitempty" yaml:"access_log"` + AccessLog *accesslog.Config `json:"access_log,omitempty" yaml:"access_log"` /* Docker only */ Container *docker.Container `json:"container,omitempty" yaml:"-"`