feat: proxmox idlewatcher (#88)

* feat: idle sleep for proxmox LXCs

* refactor: replace deprecated docker api types

* chore(api): remove debug task list endpoint

* refactor: move servemux to gphttp/servemux; favicon.go to v1/favicon

* refactor: introduce Pool interface, move agent_pool to agent module

* refactor: simplify api code

* feat: introduce debug api

* refactor: remove net.URL and net.CIDR types, improved unmarshal handling

* chore: update Makefile for debug build tag, update README

* chore: add gperr.Unwrap method

* feat: relative time and duration formatting

* chore: add ROOT_DIR environment variable, refactor

* migration: move homepage override and icon cache to $BASE_DIR/data, add migration code

* fix: nil dereference on marshalling service health

* fix: wait for route deletion

* chore: enhance tasks debuggability

* feat: stdout access logger and MultiWriter

* fix(agent): remove agent properly on verify error

* fix(metrics): disk exclusion logic and added corresponding tests

* chore: update schema and prettify, fix package.json and Makefile

* fix: I/O buffer not being shrunk before putting back to pool

* feat: enhanced error handling module

* chore: deps upgrade

* feat: better value formatting and handling

---------

Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
Yuzerion
2025-04-16 14:52:33 +08:00
committed by GitHub
parent 88f3a95b61
commit 57292f0fe8
173 changed files with 4131 additions and 2096 deletions

View File

@@ -25,11 +25,15 @@ type (
}
AccessLogIO interface {
io.Writer
sync.Locker
Name() string // file name or path
}
supportRotate interface {
io.ReadWriteCloser
io.ReadWriteSeeker
io.ReaderAt
sync.Locker
Name() string // file name or path
Truncate(size int64) error
}
@@ -40,7 +44,33 @@ type (
}
)
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
var ios []AccessLogIO
if cfg.Stdout {
ios = append(ios, stdoutIO)
}
if cfg.Path != "" {
io, err := newFileIO(cfg.Path)
if err != nil {
return nil, err
}
ios = append(ios, io)
}
if len(ios) == 0 {
return nil, nil
}
return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil
}
func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger {
return NewAccessLoggerWithIO(parent, &MockFile{}, cfg)
}
func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize
}
@@ -152,7 +182,9 @@ func (l *AccessLogger) Flush() error {
func (l *AccessLogger) close() {
l.io.Lock()
defer l.io.Unlock()
l.io.Close()
if r, ok := l.io.(io.Closer); ok {
r.Close()
}
}
func (l *AccessLogger) write(data []byte) {

View File

@@ -56,7 +56,7 @@ func fmtLog(cfg *Config) (ts string, line string) {
var buf bytes.Buffer
t := time.Now()
logger := NewAccessLogger(testTask, nil, cfg)
logger := NewMockAccessLogger(testTask, cfg)
logger.Formatter.SetGetTimeNow(func() time.Time {
return t
})

View File

@@ -7,7 +7,7 @@ import (
// BackScanner provides an interface to read a file backward line by line.
type BackScanner struct {
file AccessLogIO
file supportRotate
chunkSize int
offset int64
buffer []byte
@@ -18,7 +18,7 @@ type BackScanner struct {
// NewBackScanner creates a new Scanner to read the file backward.
// chunkSize determines the size of each read chunk from the end of the file.
func NewBackScanner(file AccessLogIO, chunkSize int) *BackScanner {
func NewBackScanner(file supportRotate, chunkSize int) *BackScanner {
size, err := file.Seek(0, io.SeekEnd)
if err != nil {
return &BackScanner{err: err}

View File

@@ -1,6 +1,10 @@
package accesslog
import "github.com/yusing/go-proxy/internal/utils"
import (
"errors"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Format string
@@ -19,7 +23,8 @@ type (
Config struct {
BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"`
Path string `json:"path"`
Stdout bool `json:"stdout"`
Filters Filters `json:"filters"`
Fields Fields `json:"fields"`
Retention *Retention `json:"retention"`
@@ -34,6 +39,13 @@ var (
const DefaultBufferSize = 64 * 1024 // 64KB
func (cfg *Config) Validate() error {
if cfg.Path == "" && !cfg.Stdout {
return errors.New("path or stdout is required")
}
return nil
}
func DefaultConfig() *Config {
return &Config{
BufferSize: DefaultBufferSize,

View File

@@ -3,11 +3,10 @@ package accesslog
import (
"fmt"
"os"
"path"
pathPkg "path"
"sync"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -27,16 +26,16 @@ var (
openedFilesMu sync.Mutex
)
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
func newFileIO(path string) (AccessLogIO, error) {
openedFilesMu.Lock()
var file *File
path := path.Clean(cfg.Path)
path = pathPkg.Clean(path)
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
file = opened
} else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err)
@@ -47,7 +46,7 @@ func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error)
}
openedFilesMu.Unlock()
return NewAccessLogger(parent, file, cfg), nil
return file, nil
}
func (f *File) Close() error {

View File

@@ -16,7 +16,6 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
cfg := DefaultConfig()
cfg.Path = "test.log"
parent := task.RootTask("test", false)
loggerCount := 10
accessLogIOs := make([]AccessLogIO, loggerCount)
@@ -33,9 +32,9 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
wg.Add(1)
go func(index int) {
defer wg.Done()
logger, err := NewFileAccessLogger(parent, cfg)
file, err := newFileIO(cfg.Path)
ExpectNoError(t, err)
accessLogIOs[index] = logger.io
accessLogIOs[index] = file
}(i)
}
@@ -59,7 +58,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount {
loggers[i] = NewAccessLogger(parent, &file, cfg)
loggers[i] = NewAccessLoggerWithIO(parent, &file, cfg)
}
var wg sync.WaitGroup

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -24,7 +23,7 @@ type (
Key, Value string
}
Host string
CIDR struct{ types.CIDR }
CIDR net.IPNet
)
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
@@ -86,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h)
}
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
ip = req.RemoteAddr
@@ -95,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
if netIP == nil {
return false
}
return cidr.Contains(netIP)
return (*net.IPNet)(cidr).Contains(netIP)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}

View File

@@ -1,6 +1,7 @@
package accesslog_test
import (
"net"
"net/http"
"testing"
@@ -155,9 +156,10 @@ func TestHeaderFilter(t *testing.T) {
}
func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{
strutils.MustParse[*CIDR]("192.168.10.0/24"),
}
cidr := []*CIDR{{
IP: net.ParseIP("192.168.10.0"),
Mask: net.CIDRMask(24, 32),
}}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",

View File

@@ -0,0 +1,46 @@
package accesslog
import "strings"
type MultiWriter struct {
writers []AccessLogIO
}
func NewMultiWriter(writers ...AccessLogIO) AccessLogIO {
if len(writers) == 0 {
return nil
}
if len(writers) == 1 {
return writers[0]
}
return &MultiWriter{
writers: writers,
}
}
func (w *MultiWriter) Write(p []byte) (n int, err error) {
for _, writer := range w.writers {
writer.Write(p)
}
return len(p), nil
}
func (w *MultiWriter) Lock() {
for _, writer := range w.writers {
writer.Lock()
}
}
func (w *MultiWriter) Unlock() {
for _, writer := range w.writers {
writer.Unlock()
}
}
func (w *MultiWriter) Name() string {
names := make([]string, len(w.writers))
for i, writer := range w.writers {
names[i] = writer.Name()
}
return strings.Join(names, ", ")
}

View File

@@ -2,11 +2,15 @@ package accesslog
import (
"bytes"
"io"
ioPkg "io"
"time"
)
func (l *AccessLogger) rotate() (err error) {
io, ok := l.io.(supportRotate)
if !ok {
return nil
}
// Get retention configuration
config := l.Config().Retention
var shouldKeep func(t time.Time, lineCount int) bool
@@ -24,7 +28,7 @@ func (l *AccessLogger) rotate() (err error) {
return nil // No retention policy set
}
s := NewBackScanner(l.io, defaultChunkSize)
s := NewBackScanner(io, defaultChunkSize)
nRead := 0
nLines := 0
for s.Scan() {
@@ -40,11 +44,11 @@ func (l *AccessLogger) rotate() (err error) {
}
beg := int64(nRead)
if _, err := l.io.Seek(-beg, io.SeekEnd); err != nil {
if _, err := io.Seek(-beg, ioPkg.SeekEnd); err != nil {
return err
}
buf := make([]byte, nRead)
if _, err := l.io.Read(buf); err != nil {
if _, err := io.Read(buf); err != nil {
return err
}
@@ -55,8 +59,13 @@ func (l *AccessLogger) rotate() (err error) {
}
func (l *AccessLogger) writeTruncate(buf []byte) (err error) {
io, ok := l.io.(supportRotate)
if !ok {
return nil
}
// Seek to beginning and truncate
if _, err := l.io.Seek(0, 0); err != nil {
if _, err := io.Seek(0, 0); err != nil {
return err
}
@@ -70,13 +79,13 @@ func (l *AccessLogger) writeTruncate(buf []byte) (err error) {
}
// Truncate file
if err = l.io.Truncate(int64(nWritten)); err != nil {
if err = io.Truncate(int64(nWritten)); err != nil {
return err
}
// check bytes written == buffer size
if nWritten != len(buf) {
return io.ErrShortWrite
return ioPkg.ErrShortWrite
}
return
}

View File

@@ -33,7 +33,7 @@ func TestParseLogTime(t *testing.T) {
func TestRetentionCommonFormat(t *testing.T) {
var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
logger := NewAccessLoggerWithIO(task.RootTask("test", false), &file, &Config{
Format: FormatCommon,
BufferSize: 1024,
})

View File

@@ -0,0 +1,18 @@
package accesslog
import (
"io"
"os"
)
type StdoutLogger struct {
io.Writer
}
var stdoutIO = &StdoutLogger{os.Stdout}
func (l *StdoutLogger) Lock() {}
func (l *StdoutLogger) Unlock() {}
func (l *StdoutLogger) Name() string {
return "stdout"
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
@@ -84,3 +85,18 @@ func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
}
return true
}
// DynamicJSONHandler serves a JSON response depending on the request type.
//
// If the request is a websocket, it serves the data for the given interval.
//
// Otherwise, it serves the data once.
func DynamicJSONHandler[ResultType any](w http.ResponseWriter, r *http.Request, getter func() ResultType, interval time.Duration) {
if httpheaders.IsWebsocket(r.Header) {
Periodic(w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, getter())
})
} else {
gphttp.RespondJSON(w, r, getter())
}
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
// TODO: stats of each server.
@@ -240,14 +239,14 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
lb.impl.ServeHTTP(srvs, rw, r)
}
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
// MarshalMap implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalMap() map[string]any {
extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) {
extra[v.Key()] = v
})
return (&monitor.JSONRepresentation{
return (&health.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
@@ -256,7 +255,7 @@ func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}).MarshalMap()
}
// Name implements health.HealthMonitor.

View File

@@ -2,9 +2,9 @@ package types
import (
"net/http"
"net/url"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
net "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
@@ -15,7 +15,7 @@ type (
_ U.NoCopy
name string
url *net.URL
url *url.URL
weight Weight
http.Handler `json:"-"`
@@ -27,7 +27,7 @@ type (
health.HealthMonitor
Name() string
Key() string
URL() *net.URL
URL() *url.URL
Weight() Weight
SetWeight(weight Weight)
TryWake() error
@@ -38,7 +38,7 @@ type (
var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
func NewServer(name string, url *url.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &server{
name: name,
url: url,
@@ -52,7 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{
weight: Weight(weight),
url: net.MustParseURL("http://localhost"),
url: &url.URL{Scheme: "http", Host: "localhost"},
}
return srv
}
@@ -61,7 +61,7 @@ func (srv *server) Name() string {
return srv.name
}
func (srv *server) URL() *net.URL {
func (srv *server) URL() *url.URL {
return srv.url
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -18,8 +17,8 @@ type (
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Allow []*net.IPNet `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Message string
}
)
@@ -27,7 +26,7 @@ type (
var (
CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{},
Allow: []*net.IPNet{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -33,7 +32,7 @@ var (
cfCIDRsMu sync.Mutex
// RFC 1918.
localCIDRs = []*types.CIDR{
localCIDRs = []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, // 127.0.0.1/32
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 240, 0, 0)}, // 172.16.0.0/12
@@ -68,7 +67,7 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer()
}
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return
}
@@ -83,7 +82,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if common.IsTest {
cfCIDRs = localCIDRs
} else {
cfCIDRs = make([]*types.CIDR, 0, 30)
cfCIDRs = make([]*net.IPNet, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
@@ -103,7 +102,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*net.IPNet) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
@@ -124,7 +123,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
}
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr))
*cfCIDRs = append(*cfCIDRs, (*net.IPNet)(cidr))
}
*cfCIDRs = append(*cfCIDRs, localCIDRs...)
return nil

View File

@@ -16,8 +16,6 @@ import (
"github.com/yusing/go-proxy/internal/watcher/events"
)
const errPagesBasePath = common.ErrorPagesBasePath
var (
setupOnce sync.Once
dirWatcher W.Watcher
@@ -26,7 +24,7 @@ var (
func setup() {
t := task.RootTask("error_page", false)
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
dirWatcher = W.NewDirectoryWatcher(t, common.ErrorPagesDir)
loadContent()
go watchDir()
}
@@ -46,7 +44,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
}
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
files, err := U.ListFiles(common.ErrorPagesDir, 0)
if err != nil {
logging.Err(err).Msg("failed to list error page resources")
return

View File

@@ -55,7 +55,7 @@ func All() map[string]*Middleware {
func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeDir, 0)
if err != nil {
logging.Err(err).Msg("failed to list middleware definitions")
return

View File

@@ -4,10 +4,10 @@ import (
"bytes"
"net"
"net/http"
"net/url"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -51,8 +51,8 @@ func TestModifyRequest(t *testing.T) {
})
t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
@@ -128,8 +128,8 @@ func TestModifyRequest(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app" + tt.path)
upstreamURL := types.MustParseURL(tt.upstreamURL)
reqURL := Must(url.Parse("https://my.app" + tt.path))
upstreamURL := Must(url.Parse(tt.upstreamURL))
opts["add_prefix"] = tt.addPrefix
result, err := newMiddlewareTest(ModifyRequest, &testArgs{

View File

@@ -4,10 +4,10 @@ import (
"bytes"
"net"
"net/http"
"net/url"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -54,8 +54,8 @@ func TestModifyResponse(t *testing.T) {
})
t.Run("response_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,

View File

@@ -5,7 +5,6 @@ import (
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
@@ -19,7 +18,7 @@ type (
// Header is the name of the header to use for the real client IP
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"required,min=1"`
From []*net.IPNet `validate:"required,min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@@ -36,7 +35,7 @@ var (
RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
From: []*net.IPNet{},
}
)

View File

@@ -7,7 +7,6 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -23,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
}
optExpected := &RealIPOpts{
Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{
From: []*net.IPNet{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),

View File

@@ -2,15 +2,15 @@ package middleware
import (
"net/http"
"net/url"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("http://example.com"),
reqURL: Must(url.Parse("http://example.com")),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
@@ -19,7 +19,7 @@ func TestRedirectToHTTPs(t *testing.T) {
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("https://example.com"),
reqURL: Must(url.Parse("https://example.com")),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)

View File

@@ -7,11 +7,11 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -80,11 +80,11 @@ type TestResult struct {
type testArgs struct {
middlewareOpt OptionsRaw
upstreamURL *types.URL
upstreamURL *url.URL
realRoundTrip bool
reqURL *types.URL
reqURL *url.URL
reqMethod string
headers http.Header
body []byte
@@ -96,13 +96,13 @@ type testArgs struct {
func (args *testArgs) setDefaults() {
if args.reqURL == nil {
args.reqURL = Must(types.ParseURL("https://example.com"))
args.reqURL = Must(url.Parse("https://example.com"))
}
if args.reqMethod == "" {
args.reqMethod = http.MethodGet
}
if args.upstreamURL == nil {
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
args.upstreamURL = Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.respHeaders == nil {
args.respHeaders = http.Header{}

View File

@@ -28,7 +28,6 @@ import (
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
)
@@ -93,7 +92,7 @@ type ReverseProxy struct {
HandlerFunc http.HandlerFunc
TargetName string
TargetURL *types.URL
TargetURL *url.URL
}
func singleJoiningSlash(a, b string) string {
@@ -133,7 +132,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper) *ReverseProxy {
func NewReverseProxy(name string, target *url.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
@@ -151,7 +150,7 @@ func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
targetQuery := p.TargetURL.RawQuery
req.URL.Scheme = p.TargetURL.Scheme
req.URL.Host = p.TargetURL.Host
req.URL.Path, req.URL.RawPath = joinURLPath(&p.TargetURL.URL, req.URL)
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {

View File

@@ -0,0 +1,61 @@
package servemux
import (
"fmt"
"net/http"
"github.com/yusing/go-proxy/internal/api/v1/auth"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
ServeMux struct {
*http.ServeMux
cfg config.ConfigInstance
}
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
)
func NewServeMux(cfg config.ConfigInstance) ServeMux {
return ServeMux{http.NewServeMux(), cfg}
}
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
var handler http.HandlerFunc
switch h := h.(type) {
case func(http.ResponseWriter, *http.Request):
handler = h
case http.Handler:
handler = h.ServeHTTP
case WithCfgHandler:
handler = func(w http.ResponseWriter, r *http.Request) {
h(mux.cfg, w, r)
}
default:
panic(fmt.Errorf("unsupported handler type: %T", h))
}
matchDomains := mux.cfg.Value().MatchDomains
if len(matchDomains) > 0 {
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}
}
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
if methods == "" {
mux.ServeMux.HandleFunc(endpoint, handler)
} else {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
}
}