refactor: remove forward auth, move module net/http to net/gphttp

This commit is contained in:
yusing
2025-03-28 07:03:35 +08:00
parent c0c6e21a16
commit 5d2df3550b
69 changed files with 321 additions and 745 deletions

View File

@@ -0,0 +1,167 @@
package accesslog
import (
"bufio"
"bytes"
"io"
"net/http"
"sync"
"time"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type (
AccessLogger struct {
task *task.Task
cfg *Config
io AccessLogIO
buffered *bufio.Writer
lineBufPool sync.Pool // buffer pool for formatting a single log line
Formatter
}
AccessLogIO interface {
io.ReadWriteCloser
io.ReadWriteSeeker
io.ReaderAt
sync.Locker
Name() string // file name or path
Truncate(size int64) error
}
Formatter interface {
// Format writes a log line to line without a trailing newline
Format(line *bytes.Buffer, req *http.Request, res *http.Response)
SetGetTimeNow(getTimeNow func() time.Time)
}
)
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize
}
if cfg.BufferSize < 4096 {
cfg.BufferSize = 4096
}
l := &AccessLogger{
task: parent.Subtask("accesslog"),
cfg: cfg,
io: io,
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
}
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
switch l.cfg.Format {
case FormatCommon:
l.Formatter = &fmt
case FormatCombined:
l.Formatter = &CombinedFormatter{fmt}
case FormatJSON:
l.Formatter = &JSONFormatter{fmt}
default: // should not happen, validation has done by validate tags
panic("invalid access log format")
}
l.lineBufPool.New = func() any {
return bytes.NewBuffer(make([]byte, 0, 1024))
}
go l.start()
return l
}
func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool {
if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) ||
!l.cfg.Filters.Method.CheckKeep(req, res) ||
!l.cfg.Filters.Headers.CheckKeep(req, res) ||
!l.cfg.Filters.CIDR.CheckKeep(req, res) {
return false
}
return true
}
func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
if !l.checkKeep(req, res) {
return
}
line := l.lineBufPool.Get().(*bytes.Buffer)
line.Reset()
defer l.lineBufPool.Put(line)
l.Formatter.Format(line, req, res)
line.WriteRune('\n')
l.write(line.Bytes())
}
func (l *AccessLogger) LogError(req *http.Request, err error) {
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
}
func (l *AccessLogger) Config() *Config {
return l.cfg
}
func (l *AccessLogger) Rotate() error {
if l.cfg.Retention == nil {
return nil
}
l.io.Lock()
defer l.io.Unlock()
return l.rotate()
}
func (l *AccessLogger) handleErr(err error) {
gperr.LogError("failed to write access log", err)
}
func (l *AccessLogger) start() {
defer func() {
if err := l.Flush(); err != nil {
l.handleErr(err)
}
l.close()
l.task.Finish(nil)
}()
// flushes the buffer every 30 seconds
flushTicker := time.NewTicker(30 * time.Second)
defer flushTicker.Stop()
for {
select {
case <-l.task.Context().Done():
return
case <-flushTicker.C:
if err := l.Flush(); err != nil {
l.handleErr(err)
}
}
}
}
func (l *AccessLogger) Flush() error {
l.io.Lock()
defer l.io.Unlock()
return l.buffered.Flush()
}
func (l *AccessLogger) close() {
l.io.Lock()
defer l.io.Unlock()
l.io.Close()
}
func (l *AccessLogger) write(data []byte) {
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
_, err := l.buffered.Write(data)
l.io.Unlock()
if err != nil {
l.handleErr(err)
} else {
logging.Debug().Msg("access log flushed to " + l.io.Name())
}
}

View File

@@ -0,0 +1,127 @@
package accesslog_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
"time"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
remote = "192.168.1.1"
host = "example.com"
uri = "/?bar=baz&foo=bar"
uriRedacted = "/?bar=" + RedactedValue + "&foo=" + RedactedValue
referer = "https://www.google.com/"
proto = "HTTP/1.1"
ua = "Go-http-client/1.1"
status = http.StatusOK
contentLength = 100
method = http.MethodGet
)
var (
testTask = task.RootTask("test", false)
testURL = Must(url.Parse("http://" + host + uri))
req = &http.Request{
RemoteAddr: remote,
Method: method,
Proto: proto,
Host: testURL.Host,
URL: testURL,
Header: http.Header{
"User-Agent": []string{ua},
"Referer": []string{referer},
"Cookie": []string{
"foo=bar",
"bar=baz",
},
},
}
resp = &http.Response{
StatusCode: status,
ContentLength: contentLength,
Header: http.Header{"Content-Type": []string{"text/plain"}},
}
)
func fmtLog(cfg *Config) (ts string, line string) {
var buf bytes.Buffer
t := time.Now()
logger := NewAccessLogger(testTask, nil, cfg)
logger.Formatter.SetGetTimeNow(func() time.Time {
return t
})
logger.Format(&buf, req, resp)
return t.Format(LogTimeFormat), buf.String()
}
func TestAccessLoggerCommon(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uri, proto, status, contentLength,
),
)
}
func TestAccessLoggerCombined(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCombined
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
host, remote, ts, method, uri, proto, status, contentLength, referer, ua,
),
)
}
func TestAccessLoggerRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
config.Fields.Query.Default = FieldModeRedact
ts, log := fmtLog(config)
ExpectEqual(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uriRedacted, proto, status, contentLength,
),
)
}
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
t.Helper()
config.Format = FormatJSON
var entry JSONLogEntry
_, log := fmtLog(config)
err := json.Unmarshal([]byte(log), &entry)
ExpectNoError(t, err)
return entry
}
func TestAccessLoggerJSON(t *testing.T) {
config := DefaultConfig()
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.IP, remote)
ExpectEqual(t, entry.Method, method)
ExpectEqual(t, entry.Scheme, "http")
ExpectEqual(t, entry.Host, testURL.Host)
ExpectEqual(t, entry.URI, testURL.RequestURI())
ExpectEqual(t, entry.Protocol, proto)
ExpectEqual(t, entry.Status, status)
ExpectEqual(t, entry.ContentType, "text/plain")
ExpectEqual(t, entry.Size, contentLength)
ExpectEqual(t, entry.Referer, referer)
ExpectEqual(t, entry.UserAgent, ua)
ExpectEqual(t, len(entry.Headers), 0)
ExpectEqual(t, len(entry.Cookies), 0)
}

View File

@@ -0,0 +1,57 @@
package accesslog
import "github.com/yusing/go-proxy/internal/utils"
type (
Format string
Filters struct {
StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"`
Method LogFilter[HTTPMethod] `json:"method"`
Host LogFilter[Host] `json:"host"`
Headers LogFilter[*HTTPHeader] `json:"headers"` // header exists or header == value
CIDR LogFilter[*CIDR] `json:"cidr"`
}
Fields struct {
Headers FieldConfig `json:"headers"`
Query FieldConfig `json:"query"`
Cookies FieldConfig `json:"cookies"`
}
Config struct {
BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"`
Fields Fields `json:"fields"`
Retention *Retention `json:"retention"`
}
)
var (
FormatCommon Format = "common"
FormatCombined Format = "combined"
FormatJSON Format = "json"
)
const DefaultBufferSize = 64 * 1024 // 64KB
func DefaultConfig() *Config {
return &Config{
BufferSize: DefaultBufferSize,
Format: FormatCombined,
Fields: Fields{
Headers: FieldConfig{
Default: FieldModeDrop,
},
Query: FieldConfig{
Default: FieldModeKeep,
},
Cookies: FieldConfig{
Default: FieldModeDrop,
},
},
}
}
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
}

View File

@@ -0,0 +1,53 @@
package accesslog_test
import (
"testing"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewConfig(t *testing.T) {
labels := map[string]string{
"proxy.buffer_size": "10",
"proxy.format": "combined",
"proxy.path": "/tmp/access.log",
"proxy.filters.status_codes.values": "200-299",
"proxy.filters.method.values": "GET, POST",
"proxy.filters.headers.values": "foo=bar, baz",
"proxy.filters.headers.negative": "true",
"proxy.filters.cidr.values": "192.168.10.0/24",
"proxy.fields.headers.default": "keep",
"proxy.fields.headers.config.foo": "redact",
"proxy.fields.query.default": "drop",
"proxy.fields.query.config.foo": "keep",
"proxy.fields.cookies.default": "redact",
"proxy.fields.cookies.config.foo": "keep",
}
parsed, err := docker.ParseLabels(labels)
ExpectNoError(t, err)
var config Config
err = utils.Deserialize(parsed, &config)
ExpectNoError(t, err)
ExpectEqual(t, config.BufferSize, 10)
ExpectEqual(t, config.Format, FormatCombined)
ExpectEqual(t, config.Path, "/tmp/access.log")
ExpectDeepEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, len(config.Filters.Method.Values), 2)
ExpectDeepEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
ExpectDeepEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectTrue(t, config.Filters.Headers.Negative)
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep)
ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop)
ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep)
ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact)
ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
}

View File

@@ -0,0 +1,103 @@
package accesslog
import (
"net/http"
"net/url"
)
type (
FieldConfig struct {
Default FieldMode `json:"default" validate:"oneof=keep drop redact"`
Config map[string]FieldMode `json:"config" validate:"dive,oneof=keep drop redact"`
}
FieldMode string
)
const (
FieldModeKeep FieldMode = "keep"
FieldModeDrop FieldMode = "drop"
FieldModeRedact FieldMode = "redact"
RedactedValue = "REDACTED"
)
func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V {
if len(cfg.Config) == 0 {
switch cfg.Default {
case FieldModeKeep:
return m
case FieldModeDrop:
return nil
case FieldModeRedact:
redacted := make(map[string]V)
for k := range m {
redacted[k] = redactedV
}
return redacted
}
}
if len(m) == 0 {
return m
}
newMap := make(map[string]V, len(m))
for k := range m {
var mode FieldMode
var ok bool
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = m[k]
case FieldModeRedact:
newMap[k] = redactedV
}
}
return newMap
}
func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn {
if len(s) == 0 ||
len(cfg.Config) == 0 && cfg.Default == FieldModeDrop {
return nil
}
newMap := make(map[string]VReturn, len(s))
for _, v := range s {
var mode FieldMode
var ok bool
k := getKey(v)
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = convert(v)
case FieldModeRedact:
newMap[k] = redact(v)
}
}
return newMap
}
func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header {
return processMap(cfg, headers, []string{RedactedValue})
}
func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values {
return processMap(cfg, q, []string{RedactedValue})
}
func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string {
return processSlice(cfg, cookies,
func(c *http.Cookie) string {
return c.Name
},
func(c *http.Cookie) string {
return c.Value
},
func(c *http.Cookie) string {
return RedactedValue
})
}

View File

@@ -0,0 +1,96 @@
package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
// Cookie header should be removed,
// stored in JSONLogEntry.Cookies instead.
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
entry := getJSONEntry(t, config)
for k, v := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], v)
}
}
config.Fields.Headers.Config = map[string]FieldMode{
"Referer": FieldModeRedact,
"User-Agent": FieldModeDrop,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Headers["User-Agent"], nil)
}
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeDrop
entry := getJSONEntry(t, config)
for k := range req.Header {
ExpectDeepEqual(t, entry.Headers[k], nil)
}
config.Fields.Headers.Config = map[string]FieldMode{
"Referer": FieldModeKeep,
"User-Agent": FieldModeRedact,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectDeepEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], []string{RedactedValue})
}
}
}
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeKeep
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
}
}
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
}
}
func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeDrop
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], nil)
ExpectDeepEqual(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Query["bar"], []string{RedactedValue})
}

View File

@@ -0,0 +1,69 @@
package accesslog
import (
"fmt"
"os"
"path"
"sync"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
type File struct {
*os.File
sync.Mutex
// os.File.Name() may not equal to key of `openedFiles`.
// Store it for later delete from `openedFiles`.
path string
refCount *utils.RefCount
}
var (
openedFiles = make(map[string]*File)
openedFilesMu sync.Mutex
)
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
openedFilesMu.Lock()
var file *File
path := path.Clean(cfg.Path)
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
file = opened
} else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err)
}
file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
}
openedFilesMu.Unlock()
return NewAccessLogger(parent, file, cfg), nil
}
func (f *File) Close() error {
f.refCount.Sub()
return nil
}
func (f *File) closeOnZero() {
defer logging.Debug().
Str("path", f.path).
Msg("access log closed")
<-f.refCount.Zero()
openedFilesMu.Lock()
delete(openedFiles, f.path)
openedFilesMu.Unlock()
f.File.Close()
}

View File

@@ -0,0 +1,95 @@
package accesslog
import (
"net/http"
"os"
"sync"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/task"
)
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
var wg sync.WaitGroup
cfg := DefaultConfig()
cfg.Path = "test.log"
parent := task.RootTask("test", false)
loggerCount := 10
accessLogIOs := make([]AccessLogIO, loggerCount)
// make test log file
file, err := os.Create(cfg.Path)
ExpectNoError(t, err)
file.Close()
t.Cleanup(func() {
ExpectNoError(t, os.Remove(cfg.Path))
})
for i := range loggerCount {
wg.Add(1)
go func(index int) {
defer wg.Done()
logger, err := NewFileAccessLogger(parent, cfg)
ExpectNoError(t, err)
accessLogIOs[index] = logger.io
}(i)
}
wg.Wait()
firstIO := accessLogIOs[0]
for _, io := range accessLogIOs {
ExpectEqual(t, io, firstIO)
}
}
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
var file MockFile
cfg := DefaultConfig()
cfg.BufferSize = 1024
parent := task.RootTask("test", false)
loggerCount := 5
logCountPerLogger := 10
loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount {
loggers[i] = NewAccessLogger(parent, &file, cfg)
}
var wg sync.WaitGroup
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp := &http.Response{StatusCode: http.StatusOK}
for _, logger := range loggers {
wg.Add(1)
go func(l *AccessLogger) {
defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger)
l.Flush()
}(logger)
}
wg.Wait()
expected := loggerCount * logCountPerLogger
actual := file.LineCount()
ExpectEqual(t, actual, expected)
}
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {
var wg sync.WaitGroup
wg.Add(n)
for range n {
go func() {
defer wg.Done()
logger.Log(req, resp)
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,99 @@
package accesslog
import (
"net"
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
LogFilter[T Filterable] struct {
Negative bool
Values []T
}
Filterable interface {
comparable
Fulfill(req *http.Request, res *http.Response) bool
}
HTTPMethod string
HTTPHeader struct {
Key, Value string
}
Host string
CIDR struct{ types.CIDR }
)
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
func (f *LogFilter[T]) CheckKeep(req *http.Request, res *http.Response) bool {
if len(f.Values) == 0 {
return !f.Negative
}
for _, check := range f.Values {
if check.Fulfill(req, res) {
return !f.Negative
}
}
return f.Negative
}
func (r *StatusCodeRange) Fulfill(req *http.Request, res *http.Response) bool {
return r.Includes(res.StatusCode)
}
func (method HTTPMethod) Fulfill(req *http.Request, res *http.Response) bool {
return req.Method == string(method)
}
// Parse implements strutils.Parser.
func (k *HTTPHeader) Parse(v string) error {
split := strutils.SplitRune(v, '=')
switch len(split) {
case 1:
split = append(split, "")
case 2:
default:
return ErrInvalidHTTPHeaderFilter.Subject(v)
}
k.Key = split[0]
k.Value = split[1]
return nil
}
func (k *HTTPHeader) Fulfill(req *http.Request, res *http.Response) bool {
wanted := k.Value
// non canonical key matching
got, ok := req.Header[k.Key]
if wanted == "" {
return ok
}
if !ok {
return false
}
for _, v := range got {
if strings.EqualFold(v, wanted) {
return true
}
}
return false
}
func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h)
}
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
ip = req.RemoteAddr
}
netIP := net.ParseIP(ip)
if netIP == nil {
return false
}
return cidr.Contains(netIP)
}

View File

@@ -0,0 +1,188 @@
package accesslog_test
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestStatusCodeFilter(t *testing.T) {
values := []*StatusCodeRange{
strutils.MustParse[*StatusCodeRange]("200-308"),
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{}
ExpectTrue(t, filter.CheckKeep(nil, nil))
// keep any 2xx 3xx (inclusive)
filter.Values = values
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(nil, nil))
// drop any 2xx 3xx (inclusive)
filter.Values = values
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
}
func TestMethodFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{}
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// keep get only
filter.Values = []HTTPMethod{http.MethodGet}
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// drop post only
filter.Values = []HTTPMethod{http.MethodPost}
ExpectFalse(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
})
}
func TestHeaderFilter(t *testing.T) {
fooBar := &http.Request{
Header: http.Header{
"Foo": []string{"bar"},
},
}
fooBaz := &http.Request{
Header: http.Header{
"Foo": []string{"baz"},
},
}
headerFoo := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo"),
}
ExpectEqual(t, headerFoo[0].Key, "Foo")
ExpectEqual(t, headerFoo[0].Value, "")
headerFooBar := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo=bar"),
}
ExpectEqual(t, headerFooBar[0].Key, "Foo")
ExpectEqual(t, headerFooBar[0].Value, "bar")
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{}
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
// keep any foo
filter.Values = headerFoo
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
// keep foo == bar
filter.Values = headerFooBar
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
// drop any foo
filter.Values = headerFoo
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
// drop foo == bar
filter.Values = headerFooBar
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
})
}
func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{
strutils.MustParse[*CIDR]("192.168.10.0/24"),
}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",
}
notInCIDR := &http.Request{
RemoteAddr: "192.168.11.1",
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*CIDR]{}
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*CIDR]{Negative: true}
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
})
}

View File

@@ -0,0 +1,144 @@
package accesslog
import (
"bytes"
"encoding/json"
"net"
"net/http"
"net/url"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/logging"
)
type (
CommonFormatter struct {
cfg *Fields
GetTimeNow func() time.Time // for testing purposes only
}
CombinedFormatter struct{ CommonFormatter }
JSONFormatter struct{ CommonFormatter }
JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
URI string `json:"uri"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
)
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
func scheme(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
}
func requestURI(u *url.URL, query url.Values) string {
uri := u.EscapedPath()
if len(query) > 0 {
uri += "?" + query.Encode()
}
return uri
}
func clientIP(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return req.RemoteAddr
}
// debug only.
func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) {
f.GetTimeNow = getTimeNow
}
func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
line.WriteString(req.Host)
line.WriteRune(' ')
line.WriteString(clientIP(req))
line.WriteString(" - - [")
line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
line.WriteString("] \"")
line.WriteString(req.Method)
line.WriteRune(' ')
line.WriteString(requestURI(req.URL, query))
line.WriteRune(' ')
line.WriteString(req.Proto)
line.WriteString("\" ")
line.WriteString(strconv.Itoa(res.StatusCode))
line.WriteRune(' ')
line.WriteString(strconv.FormatInt(res.ContentLength, 10))
}
func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
f.CommonFormatter.Format(line, req, res)
line.WriteString(" \"")
line.WriteString(req.Referer())
line.WriteString("\" \"")
line.WriteString(req.UserAgent())
line.WriteRune('"')
}
func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
entry := JSONLogEntry{
Time: f.GetTimeNow().Format(LogTimeFormat),
IP: clientIP(req),
Method: req.Method,
Scheme: scheme(req),
Host: req.Host,
URI: requestURI(req.URL, query),
Protocol: req.Proto,
Status: res.StatusCode,
ContentType: res.Header.Get("Content-Type"),
Size: res.ContentLength,
Referer: req.Referer(),
UserAgent: req.UserAgent(),
Query: query,
Headers: headers,
Cookies: cookies,
}
if res.StatusCode >= 400 {
entry.Error = res.Status
}
if entry.ContentType == "" {
// try to get content type from request
entry.ContentType = req.Header.Get("Content-Type")
}
marshaller := json.NewEncoder(line)
err := marshaller.Encode(entry)
if err != nil {
logging.Err(err).Msg("failed to marshal json log")
}
}

View File

@@ -0,0 +1,77 @@
package accesslog
import (
"bytes"
"io"
"sync"
)
type MockFile struct {
data []byte
position int64
sync.Mutex
}
func (m *MockFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case io.SeekStart:
m.position = offset
case io.SeekCurrent:
m.position += offset
case io.SeekEnd:
m.position = int64(len(m.data)) + offset
}
return m.position, nil
}
func (m *MockFile) Write(p []byte) (n int, err error) {
m.data = append(m.data, p...)
n = len(p)
m.position += int64(n)
return
}
func (m *MockFile) Name() string {
return "mock"
}
func (m *MockFile) Read(p []byte) (n int, err error) {
if m.position >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[m.position:])
m.position += int64(n)
return n, nil
}
func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
if off >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[off:])
return n, nil
}
func (m *MockFile) Close() error {
return nil
}
func (m *MockFile) Truncate(size int64) error {
m.data = m.data[:size]
m.position = size
return nil
}
func (m *MockFile) LineCount() int {
m.Lock()
defer m.Unlock()
return bytes.Count(m.data[:m.position], []byte("\n"))
}
func (m *MockFile) Len() int64 {
return m.position
}
func (m *MockFile) Content() []byte {
return m.data[:m.position]
}

View File

@@ -0,0 +1,56 @@
package accesslog
import (
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
}
var (
ErrInvalidSyntax = gperr.New("invalid syntax")
ErrZeroValue = gperr.New("zero value")
)
var defaultChunkSize = 64 * 1024 // 64KB
// Syntax:
//
// <N> days|weeks|months
//
// last <N>
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
if len(split) != 2 {
return ErrInvalidSyntax.Subject(v)
}
switch split[0] {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
return ErrZeroValue
}
return
}

View File

@@ -0,0 +1,33 @@
package accesslog_test
import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
tests := []struct {
input string
expected *Retention
shouldErr bool
}{
{"30 days", &Retention{Days: 30}, false},
{"2 weeks", &Retention{Days: 14}, false},
{"last 5", &Retention{Last: 5}, false},
{"invalid input", &Retention{}, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
}
})
}
}

View File

@@ -0,0 +1,52 @@
package accesslog
import (
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type StatusCodeRange struct {
Start int
End int
}
var ErrInvalidStatusCodeRange = gperr.New("invalid status code range")
func (r *StatusCodeRange) Includes(code int) bool {
return r.Start <= code && code <= r.End
}
// Parse implements strutils.Parser.
func (r *StatusCodeRange) Parse(v string) error {
split := strutils.SplitRune(v, '-')
switch len(split) {
case 1:
start, err := strconv.Atoi(split[0])
if err != nil {
return gperr.Wrap(err)
}
r.Start = start
r.End = start
return nil
case 2:
start, errStart := strconv.Atoi(split[0])
end, errEnd := strconv.Atoi(split[1])
if err := gperr.Join(errStart, errEnd); err != nil {
return err
}
r.Start = start
r.End = end
return nil
default:
return ErrInvalidStatusCodeRange.Subject(v)
}
}
func (r *StatusCodeRange) String() string {
if r.Start == r.End {
return strconv.Itoa(r.Start)
}
return strconv.Itoa(r.Start) + "-" + strconv.Itoa(r.End)
}