From b427ff1f881b027ddaa1fb1b19b264d26d277fdd Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 25 Apr 2025 10:47:52 +0800 Subject: [PATCH] feat(acl): connection level ip/geo blocking - fixed access log logic - implement acl at connection level - acl logging - ip/cidr blocking - geoblocking with MaxMind database --- agent/pkg/server/server.go | 2 +- go.mod | 4 +- go.sum | 4 + internal/acl/city_cache.go | 39 +++ internal/acl/config.go | 215 ++++++++++++++ internal/acl/matcher.go | 99 ++++++ internal/acl/maxmind.go | 281 ++++++++++++++++++ internal/acl/maxmind_test.go | 213 +++++++++++++ internal/acl/tcp_listener.go | 46 +++ internal/acl/types/city_info.go | 10 + internal/acl/types/ip_info.go | 9 + internal/acl/udp_listener.go | 79 +++++ internal/config/config.go | 9 + internal/config/types/config.go | 13 +- internal/config/types/homepage_config.go | 5 - internal/entrypoint/entrypoint.go | 2 +- internal/logging/accesslog/access_logger.go | 219 ++++++++------ .../logging/accesslog/access_logger_test.go | 18 +- internal/logging/accesslog/back_scanner.go | 11 +- .../logging/accesslog/back_scanner_test.go | 4 +- internal/logging/accesslog/config.go | 95 ++++-- internal/logging/accesslog/config_test.go | 2 +- internal/logging/accesslog/fields_test.go | 14 +- internal/logging/accesslog/file_logger.go | 8 +- .../logging/accesslog/file_logger_test.go | 6 +- internal/logging/accesslog/formatter.go | 32 +- internal/logging/accesslog/multi_writer.go | 33 +- internal/logging/accesslog/rotate.go | 12 +- internal/logging/accesslog/rotate_test.go | 34 ++- internal/logging/accesslog/stdout_logger.go | 2 - internal/net/gphttp/server/server.go | 18 +- internal/route/route.go | 14 +- 32 files changed, 1359 insertions(+), 193 deletions(-) create mode 100644 internal/acl/city_cache.go create mode 100644 internal/acl/config.go create mode 100644 internal/acl/matcher.go create mode 100644 internal/acl/maxmind.go create mode 100644 internal/acl/maxmind_test.go create mode 100644 internal/acl/tcp_listener.go create mode 100644 internal/acl/types/city_info.go create mode 100644 internal/acl/types/ip_info.go create mode 100644 internal/acl/udp_listener.go delete mode 100644 internal/config/types/homepage_config.go diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index be0ac1dd..9be4631d 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -40,5 +40,5 @@ func StartAgentServer(parent task.Parent, opt Options) { TLSConfig: tlsConfig, } - server.Start(parent, agentServer, logger) + server.Start(parent, agentServer, nil, logger) } diff --git a/go.mod b/go.mod index 729ba4be..cc5029e2 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( golang.org/x/oauth2 v0.29.0 // oauth2 authentication golang.org/x/text v0.24.0 // string utilities golang.org/x/time v0.11.0 // time utilities - gopkg.in/yaml.v3 v3.0.1 // yaml parsing for different config files + gopkg.in/yaml.v3 v3.0.1 // indirect; yaml parsing for different config files ) replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2 @@ -30,8 +30,10 @@ replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2 require ( github.com/bytedance/sonic v1.13.2 github.com/docker/cli v28.1.1+incompatible + github.com/goccy/go-yaml v1.17.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/luthermonson/go-proxmox v0.2.2 + github.com/oschwald/maxminddb-golang v1.13.1 github.com/quic-go/quic-go v0.51.0 github.com/samber/slog-zerolog/v2 v2.7.3 github.com/spf13/afero v1.14.0 diff --git a/go.sum b/go.sum index 19743af1..1dcde465 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e h1:LEbMtJ6loEubxetD+Aw8+1x0rShor5iMoy9WuFQ8hN8= github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e/go.mod h1:3tMTnTkH7IN5smn7PX83XdmRnNj4Nw2/Pt8GgReqnKM= @@ -163,6 +165,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= +github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/ovh/go-ovh v1.7.0 h1:V14nF7FwDjQrZt9g7jzcvAAQ3HN6DNShRFRMC3jLoPw= github.com/ovh/go-ovh v1.7.0/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c= github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= diff --git a/internal/acl/city_cache.go b/internal/acl/city_cache.go new file mode 100644 index 00000000..1c496d5e --- /dev/null +++ b/internal/acl/city_cache.go @@ -0,0 +1,39 @@ +package acl + +import ( + "github.com/puzpuzpuz/xsync/v3" + acl "github.com/yusing/go-proxy/internal/acl/types" + "go.uber.org/atomic" +) + +var cityCache = xsync.NewMapOf[string, *acl.City]() +var numCachedLookup atomic.Uint64 + +func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) { + if ip.City != nil { + return ip.City, true + } + + if cfg.db.Reader == nil { + return nil, false + } + + city, ok := cityCache.Load(ip.Str) + if ok { + numCachedLookup.Inc() + return city, true + } + + cfg.db.RLock() + defer cfg.db.RUnlock() + + city = new(acl.City) + err := cfg.db.Lookup(ip.IP, city) + if err != nil { + return nil, false + } + + cityCache.Store(ip.Str, city) + ip.City = city + return city, true +} diff --git a/internal/acl/config.go b/internal/acl/config.go new file mode 100644 index 00000000..e56cbc02 --- /dev/null +++ b/internal/acl/config.go @@ -0,0 +1,215 @@ +package acl + +import ( + "net" + "sync" + "time" + + "github.com/oschwald/maxminddb-golang" + "github.com/puzpuzpuz/xsync/v3" + "github.com/rs/zerolog" + acl "github.com/yusing/go-proxy/internal/acl/types" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/logging/accesslog" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" +) + +type Config struct { + Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow + AllowLocal *bool `json:"allow_local"` // default: true + Allow []string `json:"allow"` + Deny []string `json:"deny"` + Log *accesslog.ACLLoggerConfig `json:"log"` + + MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"` + + config +} + +type ( + MaxMindDatabaseType string + MaxMindConfig struct { + AccountID string `json:"account_id" validate:"required"` + LicenseKey string `json:"license_key" validate:"required"` + Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"` + + logger zerolog.Logger + lastUpdate time.Time + db struct { + *maxminddb.Reader + sync.RWMutex + } + } +) + +type config struct { + defaultAllow bool + allowLocal bool + allow []matcher + deny []matcher + ipCache *xsync.MapOf[string, *checkCache] + logAllowed bool + logger *accesslog.AccessLogger +} + +type checkCache struct { + *acl.IPInfo + allow bool + created time.Time +} + +const cacheTTL = 1 * time.Minute + +func (c *checkCache) Expired() bool { + return c.created.Add(cacheTTL).After(utils.TimeNow()) +} + +//TODO: add stats + +const ( + ACLAllow = "allow" + ACLDeny = "deny" +) + +const ( + MaxMindGeoLite MaxMindDatabaseType = "geolite" + MaxMindGeoIP2 MaxMindDatabaseType = "geoip2" +) + +func (c *Config) Validate() gperr.Error { + switch c.Default { + case "", ACLAllow: + c.defaultAllow = true + case ACLDeny: + c.defaultAllow = false + default: + return gperr.New("invalid default value").Subject(c.Default) + } + + if c.AllowLocal != nil { + c.allowLocal = *c.AllowLocal + } else { + c.allowLocal = true + } + + if c.MaxMind != nil { + c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger() + } + + if c.Log != nil { + c.logAllowed = c.Log.LogAllowed + } + + errs := gperr.NewBuilder("syntax error") + c.allow = make([]matcher, 0, len(c.Allow)) + c.deny = make([]matcher, 0, len(c.Deny)) + + for _, s := range c.Allow { + m, err := c.parseMatcher(s) + if err != nil { + errs.Add(err.Subject(s)) + continue + } + c.allow = append(c.allow, m) + } + for _, s := range c.Deny { + m, err := c.parseMatcher(s) + if err != nil { + errs.Add(err.Subject(s)) + continue + } + c.deny = append(c.deny, m) + } + + if errs.HasError() { + c.allow = nil + c.deny = nil + return errMatcherFormat.With(errs.Error()) + } + + c.ipCache = xsync.NewMapOf[string, *checkCache]() + return nil +} + +func (c *Config) Valid() bool { + return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal) +} + +func (c *Config) Start(parent *task.Task) gperr.Error { + if c.MaxMind != nil { + if err := c.MaxMind.LoadMaxMindDB(parent); err != nil { + return err + } + } + if c.Log != nil { + logger, err := accesslog.NewAccessLogger(parent, c.Log) + if err != nil { + return gperr.New("failed to start access logger").With(err) + } + c.logger = logger + } + return nil +} + +func (c *config) cacheRecord(info *acl.IPInfo, allow bool) { + c.ipCache.Store(info.Str, &checkCache{ + IPInfo: info, + allow: allow, + created: utils.TimeNow(), + }) +} + +func (c *config) log(info *acl.IPInfo, allowed bool) { + if c.logger == nil { + return + } + if !allowed || c.logAllowed { + c.logger.LogACL(info, !allowed) + } +} + +func (c *Config) IPAllowed(ip net.IP) bool { + if ip == nil { + return false + } + + // always allow private and loopback + // loopback is not logged + if ip.IsLoopback() { + return true + } + + if c.allowLocal && ip.IsPrivate() { + c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true) + return true + } + + ipStr := ip.String() + record, ok := c.ipCache.Load(ipStr) + if ok && !record.Expired() { + c.log(record.IPInfo, record.allow) + return record.allow + } + + ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr} + for _, m := range c.allow { + if m(ipAndStr) { + c.log(ipAndStr, true) + c.cacheRecord(ipAndStr, true) + return true + } + } + for _, m := range c.deny { + if m(ipAndStr) { + c.log(ipAndStr, false) + c.cacheRecord(ipAndStr, false) + return false + } + } + + c.log(ipAndStr, c.defaultAllow) + c.cacheRecord(ipAndStr, c.defaultAllow) + return c.defaultAllow +} diff --git a/internal/acl/matcher.go b/internal/acl/matcher.go new file mode 100644 index 00000000..73660c2d --- /dev/null +++ b/internal/acl/matcher.go @@ -0,0 +1,99 @@ +package acl + +import ( + "net" + "strings" + + acl "github.com/yusing/go-proxy/internal/acl/types" + "github.com/yusing/go-proxy/internal/gperr" +) + +type matcher func(*acl.IPInfo) bool + +const ( + MatcherTypeIP = "ip" + MatcherTypeCIDR = "cidr" + MatcherTypeTimeZone = "tz" + MatcherTypeISO = "iso" +) + +var errMatcherFormat = gperr.Multiline().AddLines( + "invalid matcher format, expect {type}:{value}", + "Available types: ip|cidr|tz|iso", + "ip:127.0.0.1", + "cidr:127.0.0.0/8", + "tz:Asia/Shanghai", + "iso:GB", +) +var ( + errSyntax = gperr.New("syntax error") + errInvalidIP = gperr.New("invalid IP") + errInvalidCIDR = gperr.New("invalid CIDR") + errMaxMindNotConfigured = gperr.New("MaxMind not configured") +) + +func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) { + parts := strings.Split(s, ":") + if len(parts) != 2 { + return nil, errSyntax + } + + switch parts[0] { + case MatcherTypeIP: + ip := net.ParseIP(parts[1]) + if ip == nil { + return nil, errInvalidIP + } + return matchIP(ip), nil + case MatcherTypeCIDR: + _, net, err := net.ParseCIDR(parts[1]) + if err != nil { + return nil, errInvalidCIDR + } + return matchCIDR(net), nil + case MatcherTypeTimeZone: + if cfg.MaxMind == nil { + return nil, errMaxMindNotConfigured + } + return cfg.MaxMind.matchTimeZone(parts[1]), nil + case MatcherTypeISO: + if cfg.MaxMind == nil { + return nil, errMaxMindNotConfigured + } + return cfg.MaxMind.matchISO(parts[1]), nil + default: + return nil, errSyntax + } +} + +func matchIP(ip net.IP) matcher { + return func(ip2 *acl.IPInfo) bool { + return ip.Equal(ip2.IP) + } +} + +func matchCIDR(n *net.IPNet) matcher { + return func(ip *acl.IPInfo) bool { + return n.Contains(ip.IP) + } +} + +func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher { + return func(ip *acl.IPInfo) bool { + city, ok := cfg.lookupCity(ip) + if !ok { + return false + } + return city.Location.TimeZone == tz + } +} + +func (cfg *MaxMindConfig) matchISO(iso string) matcher { + return func(ip *acl.IPInfo) bool { + city, ok := cfg.lookupCity(ip) + if !ok { + return false + } + return city.Country.IsoCode == iso + } +} diff --git a/internal/acl/maxmind.go b/internal/acl/maxmind.go new file mode 100644 index 00000000..20684470 --- /dev/null +++ b/internal/acl/maxmind.go @@ -0,0 +1,281 @@ +package acl + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/oschwald/maxminddb-golang" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/task" +) + +var ( + updateInterval = 24 * time.Hour + httpClient = &http.Client{ + Timeout: 10 * time.Second, + } + ErrResponseNotOK = gperr.New("response not OK") + ErrDownloadFailure = gperr.New("download failure") +) + +func dbPathImpl(dbType MaxMindDatabaseType) string { + if dbType == MaxMindGeoLite { + return filepath.Join(dataDir, "GeoLite2-City.mmdb") + } + return filepath.Join(dataDir, "GeoIP2-City.mmdb") +} + +func dbURLimpl(dbType MaxMindDatabaseType) string { + if dbType == MaxMindGeoLite { + return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz" + } + return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz" +} + +func dbFilename(dbType MaxMindDatabaseType) string { + if dbType == MaxMindGeoLite { + return "GeoLite2-City.mmdb" + } + return "GeoIP2-City.mmdb" +} + +func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error { + if cfg.Database == "" { + return nil + } + + path := dbPath(cfg.Database) + reader, err := maxmindDBOpen(path) + exists := true + if err != nil { + switch { + case errors.Is(err, os.ErrNotExist): + default: + // ignore invalid error, just download it again + var invalidErr maxminddb.InvalidDatabaseError + if !errors.As(err, &invalidErr) { + return gperr.Wrap(err) + } + } + exists = false + } + + if !exists { + cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...") + reader, err = cfg.download() + if err != nil { + return ErrDownloadFailure.With(err) + } + } + cfg.logger.Info().Msg("MaxMind DB loaded") + + cfg.db.Reader = reader + go cfg.scheduleUpdate(parent) + return nil +} + +func (cfg *MaxMindConfig) loadLastUpdate() { + f, err := os.Stat(dbPath(cfg.Database)) + if err != nil { + return + } + cfg.lastUpdate = f.ModTime() +} + +func (cfg *MaxMindConfig) setLastUpdate(t time.Time) { + cfg.lastUpdate = t + _ = os.Chtimes(dbPath(cfg.Database), t, t) +} + +func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) { + task := parent.Subtask("schedule_update", true) + ticker := time.NewTicker(updateInterval) + + cfg.loadLastUpdate() + cfg.update() + + defer func() { + ticker.Stop() + if cfg.db.Reader != nil { + cfg.db.Reader.Close() + } + task.Finish(nil) + }() + + for { + select { + case <-task.Context().Done(): + return + case <-ticker.C: + cfg.update() + } + } +} + +func (cfg *MaxMindConfig) update() { + // check for update + cfg.logger.Info().Msg("checking for MaxMind DB update...") + remoteLastModified, err := cfg.checkLastest() + if err != nil { + cfg.logger.Err(err).Msg("failed to check MaxMind DB update") + return + } + if remoteLastModified.Equal(cfg.lastUpdate) { + cfg.logger.Info().Msg("MaxMind DB is up to date") + return + } + + cfg.logger.Info(). + Time("latest", remoteLastModified.Local()). + Time("current", cfg.lastUpdate). + Msg("MaxMind DB update available") + reader, err := cfg.download() + if err != nil { + cfg.logger.Err(err).Msg("failed to update MaxMind DB") + return + } + cfg.db.Lock() + cfg.db.Close() + cfg.db.Reader = reader + cfg.setLastUpdate(*remoteLastModified) + cfg.db.Unlock() + + cfg.logger.Info().Msg("MaxMind DB updated") +} + +func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) { + req, err := http.NewRequest(method, dbURL(cfg.Database), nil) + if err != nil { + return nil, err + } + req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + return resp, nil +} + +func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) { + resp, err := newReq(cfg, http.MethodHead) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode) + } + + lastModified := resp.Header.Get("Last-Modified") + if lastModified == "" { + cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped") + return nil, nil + } + + lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified) + if err != nil { + cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped") + return nil, err + } + + return &lastModifiedTime, nil +} + +func (cfg *MaxMindConfig) download() (*maxminddb.Reader, error) { + resp, err := newReq(cfg, http.MethodGet) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode) + } + + path := dbPath(cfg.Database) + tmpPath := path + "-tmp.tar.gz" + file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return nil, err + } + + cfg.logger.Info().Msg("MaxMind DB downloading...") + + _, err = io.Copy(file, resp.Body) + if err != nil { + file.Close() + return nil, err + } + + file.Close() + + // extract .tar.gz and move only the dbFilename to path + err = extractFileFromTarGz(tmpPath, dbFilename(cfg.Database), path) + if err != nil { + return nil, gperr.New("failed to extract database from archive").With(err) + } + // cleanup the tar.gz file + _ = os.Remove(tmpPath) + + db, err := maxmindDBOpen(path) + if err != nil { + return nil, err + } + return db, nil +} + +func extractFileFromTarGz(tarGzPath, targetFilename, destPath string) error { + f, err := os.Open(tarGzPath) + if err != nil { + return err + } + defer f.Close() + + gzr, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + return err + } + // Only extract the file that matches targetFilename (basename match) + if filepath.Base(hdr.Name) == targetFilename { + outFile, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, hdr.FileInfo().Mode()) + if err != nil { + return err + } + defer outFile.Close() + _, err = io.Copy(outFile, tr) + if err != nil { + return err + } + return nil // Done + } + } + return fmt.Errorf("file %s not found in archive", targetFilename) +} + +var ( + dataDir = common.DataDir + dbURL = dbURLimpl + dbPath = dbPathImpl + maxmindDBOpen = maxminddb.Open + newReq = (*MaxMindConfig).newReq +) diff --git a/internal/acl/maxmind_test.go b/internal/acl/maxmind_test.go new file mode 100644 index 00000000..62f2cca8 --- /dev/null +++ b/internal/acl/maxmind_test.go @@ -0,0 +1,213 @@ +package acl + +import ( + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/oschwald/maxminddb-golang" + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/task" +) + +func Test_dbPath(t *testing.T) { + tmpDataDir := "/tmp/testdata" + oldDataDir := dataDir + dataDir = tmpDataDir + defer func() { dataDir = oldDataDir }() + + tests := []struct { + name string + dbType MaxMindDatabaseType + want string + }{ + {"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")}, + {"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dbPath(tt.dbType); got != tt.want { + t.Errorf("dbPath() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_dbURL(t *testing.T) { + tests := []struct { + name string + dbType MaxMindDatabaseType + want string + }{ + {"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"}, + {"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dbURL(tt.dbType); got != tt.want { + t.Errorf("dbURL() = %v, want %v", got, tt.want) + } + }) + } +} + +// --- Helper for MaxMindConfig --- +type testLogger struct{ zerolog.Logger } + +func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} } +func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} } +func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} } + +func Test_MaxMindConfig_newReq(t *testing.T) { + cfg := &MaxMindConfig{ + AccountID: "testid", + LicenseKey: "testkey", + Database: MaxMindGeoLite, + logger: zerolog.Nop(), + } + + // Patch httpClient to use httptest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" { + t.Errorf("basic auth not set correctly") + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + oldURL := dbURL + dbURL = func(MaxMindDatabaseType) string { return server.URL } + defer func() { dbURL = oldURL }() + + resp, err := cfg.newReq(http.MethodGet) + if err != nil { + t.Fatalf("newReq() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("unexpected status: %v", resp.StatusCode) + } +} + +func Test_MaxMindConfig_checkUpdate(t *testing.T) { + cfg := &MaxMindConfig{ + AccountID: "id", + LicenseKey: "key", + Database: MaxMindGeoLite, + logger: zerolog.Nop(), + } + lastMod := time.Now().UTC().Format(http.TimeFormat) + buildTime := time.Now().Add(-time.Hour) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Last-Modified", lastMod) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + oldURL := dbURL + dbURL = func(MaxMindDatabaseType) string { return server.URL } + defer func() { dbURL = oldURL }() + + latest, err := cfg.checkLastest() + if err != nil { + t.Fatalf("checkUpdate() error = %v", err) + } + if latest.Equal(buildTime) { + t.Errorf("expected update needed") + } +} + +type fakeReadCloser struct { + firstRead bool + closed bool +} + +func (c *fakeReadCloser) Read(p []byte) (int, error) { + if !c.firstRead { + c.firstRead = true + return strings.NewReader("FAKEMMDB").Read(p) + } + return 0, io.EOF +} + +func (c *fakeReadCloser) Close() error { + c.closed = true + return nil +} + +func Test_MaxMindConfig_download(t *testing.T) { + cfg := &MaxMindConfig{ + AccountID: "id", + LicenseKey: "key", + Database: MaxMindGeoLite, + logger: zerolog.Nop(), + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, strings.NewReader("FAKEMMDB")) + })) + defer server.Close() + oldURL := dbURL + dbURL = func(MaxMindDatabaseType) string { return server.URL } + defer func() { dbURL = oldURL }() + + tmpDir := t.TempDir() + oldDataDir := dataDir + dataDir = tmpDir + defer func() { dataDir = oldDataDir }() + + // Patch maxminddb.Open to always succeed + origOpen := maxmindDBOpen + maxmindDBOpen = func(path string) (*maxminddb.Reader, error) { + return &maxminddb.Reader{}, nil + } + defer func() { maxmindDBOpen = origOpen }() + + rw := &fakeReadCloser{} + oldNewReq := newReq + newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: rw, + }, nil + } + defer func() { newReq = oldNewReq }() + + db, err := cfg.download() + if err != nil { + t.Fatalf("download() error = %v", err) + } + if db == nil { + t.Error("expected db instance") + } + if !rw.closed { + t.Error("expected rw to be closed") + } +} + +func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) { + // This test should cover both the path where DB exists and where it does not + // For brevity, only the non-existing path is tested here + cfg := &MaxMindConfig{ + AccountID: "id", + LicenseKey: "key", + Database: MaxMindGeoLite, + logger: zerolog.Nop(), + } + oldOpen := maxmindDBOpen + maxmindDBOpen = func(path string) (*maxminddb.Reader, error) { + return &maxminddb.Reader{}, nil + } + defer func() { maxmindDBOpen = oldOpen }() + + oldDBPath := dbPath + dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") } + defer func() { dbPath = oldDBPath }() + + task := task.RootTask("test") + defer task.Finish(nil) + err := cfg.LoadMaxMindDB(task) + if err != nil { + t.Errorf("loadMaxMindDB() error = %v", err) + } +} diff --git a/internal/acl/tcp_listener.go b/internal/acl/tcp_listener.go new file mode 100644 index 00000000..c9361d1a --- /dev/null +++ b/internal/acl/tcp_listener.go @@ -0,0 +1,46 @@ +package acl + +import ( + "net" +) + +type TCPListener struct { + acl *Config + lis net.Listener +} + +func (cfg *Config) WrapTCP(lis net.Listener) net.Listener { + if cfg == nil { + return lis + } + return &TCPListener{ + acl: cfg, + lis: lis, + } +} + +func (s *TCPListener) Addr() net.Addr { + return s.lis.Addr() +} + +func (s *TCPListener) Accept() (net.Conn, error) { + c, err := s.lis.Accept() + if err != nil { + return nil, err + } + addr, ok := c.RemoteAddr().(*net.TCPAddr) + if !ok { + // Not a TCPAddr, drop + c.Close() + return nil, nil + } + if !s.acl.IPAllowed(addr.IP) { + c.Close() + return nil, nil + } + return c, nil +} + +func (s *TCPListener) Close() error { + return s.lis.Close() +} diff --git a/internal/acl/types/city_info.go b/internal/acl/types/city_info.go new file mode 100644 index 00000000..05b43152 --- /dev/null +++ b/internal/acl/types/city_info.go @@ -0,0 +1,10 @@ +package acl + +type City struct { + Location struct { + TimeZone string `maxminddb:"time_zone"` + } `maxminddb:"location"` + Country struct { + IsoCode string `maxminddb:"iso_code"` + } `maxminddb:"country"` +} diff --git a/internal/acl/types/ip_info.go b/internal/acl/types/ip_info.go new file mode 100644 index 00000000..13dec8ba --- /dev/null +++ b/internal/acl/types/ip_info.go @@ -0,0 +1,9 @@ +package acl + +import "net" + +type IPInfo struct { + IP net.IP + Str string + City *City +} diff --git a/internal/acl/udp_listener.go b/internal/acl/udp_listener.go new file mode 100644 index 00000000..ac51ee4e --- /dev/null +++ b/internal/acl/udp_listener.go @@ -0,0 +1,79 @@ +package acl + +import ( + "net" + "time" +) + +type UDPListener struct { + acl *Config + lis net.PacketConn +} + +func (cfg *Config) WrapUDP(lis net.PacketConn) net.PacketConn { + if cfg == nil { + return lis + } + return &UDPListener{ + acl: cfg, + lis: lis, + } +} + +func (s *UDPListener) LocalAddr() net.Addr { + return s.lis.LocalAddr() +} + +func (s *UDPListener) ReadFrom(p []byte) (int, net.Addr, error) { + for { + n, addr, err := s.lis.ReadFrom(p) + if err != nil { + return n, addr, err + } + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // Not a UDPAddr, drop + continue + } + if !s.acl.IPAllowed(udpAddr.IP) { + // Drop packet from disallowed IP + continue + } + return n, addr, nil + } +} + +func (s *UDPListener) WriteTo(p []byte, addr net.Addr) (int, error) { + for { + n, err := s.lis.WriteTo(p, addr) + if err != nil { + return n, err + } + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // Not a UDPAddr, drop + continue + } + if !s.acl.IPAllowed(udpAddr.IP) { + // Drop packet to disallowed IP + continue + } + return n, nil + } +} + +func (s *UDPListener) SetDeadline(t time.Time) error { + return s.lis.SetDeadline(t) +} + +func (s *UDPListener) SetReadDeadline(t time.Time) error { + return s.lis.SetReadDeadline(t) +} + +func (s *UDPListener) SetWriteDeadline(t time.Time) error { + return s.lis.SetWriteDeadline(t) +} + +func (s *UDPListener) Close() error { + return s.lis.Close() +} diff --git a/internal/config/config.go b/internal/config/config.go index 663c4d62..73da4583 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -197,6 +197,7 @@ func (cfg *Config) StartServers(opts ...*StartServersOptions) { HTTPAddr: common.ProxyHTTPAddr, HTTPSAddr: common.ProxyHTTPSAddr, Handler: cfg.entrypoint, + ACL: cfg.value.ACL, }) } if opt.API { @@ -237,6 +238,14 @@ func (cfg *Config) load() gperr.Error { } } cfg.entrypoint.SetFindRouteDomains(model.MatchDomains) + if model.ACL.Valid() { + err := model.ACL.Start(cfg.task) + if err != nil { + errs.Add(err) + } else { + logging.Info().Msg("ACL started") + } + } return errs.Error() } diff --git a/internal/config/types/config.go b/internal/config/types/config.go index f8e693b7..101e5717 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -1,4 +1,4 @@ -package types +package config import ( "context" @@ -7,15 +7,17 @@ import ( "github.com/go-playground/validator/v10" "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/internal/acl" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" + "github.com/yusing/go-proxy/internal/logging/accesslog" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/utils" ) type ( Config struct { + ACL *acl.Config `json:"acl"` AutoCert *autocert.AutocertConfig `json:"autocert"` Entrypoint Entrypoint `json:"entrypoint"` Providers Providers `json:"providers"` @@ -30,8 +32,11 @@ type ( Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` } Entrypoint struct { - Middlewares []map[string]any `json:"middlewares"` - AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` + Middlewares []map[string]any `json:"middlewares"` + AccessLog *accesslog.RequestLoggerConfig `json:"access_log" validate:"omitempty"` + } + HomepageConfig struct { + UseDefaultCategories bool `json:"use_default_categories"` } ConfigInstance interface { diff --git a/internal/config/types/homepage_config.go b/internal/config/types/homepage_config.go deleted file mode 100644 index 14301c8d..00000000 --- a/internal/config/types/homepage_config.go +++ /dev/null @@ -1,5 +0,0 @@ -package types - -type HomepageConfig struct { - UseDefaultCategories bool `json:"use_default_categories"` -} diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 643588a4..0f6d2022 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -54,7 +54,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { return nil } -func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { +func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) (err error) { if cfg == nil { ep.accessLogger = nil return diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index bf45847b..20eea7c5 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -8,6 +8,7 @@ import ( "time" "github.com/rs/zerolog" + acl "github.com/yusing/go-proxy/internal/acl/types" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" @@ -17,11 +18,14 @@ import ( type ( AccessLogger struct { - task *task.Task - cfg *Config - io AccessLogIO - buffered *bufio.Writer - supportRotate bool + task *task.Task + cfg *Config + + closer []io.Closer + supportRotate []supportRotate + writer *bufio.Writer + writeLock sync.Mutex + closed bool lineBufPool *synk.BytesPool // buffer pool for formatting a single log line @@ -29,85 +33,104 @@ type ( logger zerolog.Logger - Formatter + RequestFormatter + ACLFormatter } - AccessLogIO interface { + WriterWithName interface { io.Writer - sync.Locker Name() string // file name or path } - Formatter interface { - // AppendLog appends a log line to line with or without a trailing newline - AppendLog(line []byte, req *http.Request, res *http.Response) []byte + RequestFormatter interface { + // AppendRequestLog appends a log line to line with or without a trailing newline + AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte + } + ACLFormatter interface { + // AppendACLLog appends a log line to line with or without a trailing newline + AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte } ) -const MinBufferSize = 4 * kilobyte +const ( + MinBufferSize = 4 * kilobyte + MaxBufferSize = 1 * megabyte +) const ( flushInterval = 30 * time.Second rotateInterval = time.Hour ) -func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { - var ios []AccessLogIO +const ( + errRateLimit = 200 * time.Millisecond + errBurst = 5 +) - if cfg.Stdout { - ios = append(ios, stdoutIO) +func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) { + io, err := cfg.IO() + if err != nil { + return nil, err } - - if cfg.Path != "" { - io, err := newFileIO(cfg.Path) - if err != nil { - return nil, err - } - ios = append(ios, io) - } - - if len(ios) == 0 { - return nil, nil - } - - return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil + return NewAccessLoggerWithIO(parent, io, cfg), nil } -func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger { +func NewMockAccessLogger(parent task.Parent, cfg *RequestLoggerConfig) *AccessLogger { return NewAccessLoggerWithIO(parent, NewMockFile(), cfg) } -func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { +func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg AnyConfig) *AccessLogger { + cfg := anyCfg.ToConfig() if cfg.BufferSize == 0 { cfg.BufferSize = DefaultBufferSize } if cfg.BufferSize < MinBufferSize { cfg.BufferSize = MinBufferSize } + if cfg.BufferSize > MaxBufferSize { + cfg.BufferSize = MaxBufferSize + } l := &AccessLogger{ - task: parent.Subtask("accesslog."+io.Name(), true), + task: parent.Subtask("accesslog."+writer.Name(), true), cfg: cfg, - io: io, - buffered: bufio.NewWriterSize(io, cfg.BufferSize), - lineBufPool: synk.NewBytesPool(1024, synk.DefaultMaxBytes), - errRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1), - logger: logging.With().Str("file", io.Name()).Logger(), + writer: bufio.NewWriterSize(writer, cfg.BufferSize), + lineBufPool: synk.NewBytesPool(512, 8192), + errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst), + logger: logging.With().Str("file", writer.Name()).Logger(), } - fmt := CommonFormatter{cfg: &l.cfg.Fields} - switch l.cfg.Format { - case FormatCommon: - l.Formatter = &fmt - case FormatCombined: - l.Formatter = &CombinedFormatter{fmt} - case FormatJSON: - l.Formatter = &JSONFormatter{fmt} - default: // should not happen, validation has done by validate tags - panic("invalid access log format") + if unwrapped, ok := writer.(MultiWriterInterface); ok { + for _, w := range unwrapped.Unwrap() { + if sr, ok := w.(supportRotate); ok { + l.supportRotate = append(l.supportRotate, sr) + } + if closer, ok := w.(io.Closer); ok { + l.closer = append(l.closer, closer) + } + } + } else { + if sr, ok := writer.(supportRotate); ok { + l.supportRotate = append(l.supportRotate, sr) + } + if closer, ok := writer.(io.Closer); ok { + l.closer = append(l.closer, closer) + } } - if _, ok := l.io.(supportRotate); ok { - l.supportRotate = true + if cfg.req != nil { + fmt := CommonFormatter{cfg: &cfg.req.Fields} + switch cfg.req.Format { + case FormatCommon: + l.RequestFormatter = &fmt + case FormatCombined: + l.RequestFormatter = &CombinedFormatter{fmt} + case FormatJSON: + l.RequestFormatter = &JSONFormatter{fmt} + default: // should not happen, validation has done by validate tags + panic("invalid access log format") + } + } else { + l.ACLFormatter = ACLLogFormatter{} } go l.start() @@ -119,10 +142,10 @@ func (l *AccessLogger) Config() *Config { } func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool { - if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) || - !l.cfg.Filters.Method.CheckKeep(req, res) || - !l.cfg.Filters.Headers.CheckKeep(req, res) || - !l.cfg.Filters.CIDR.CheckKeep(req, res) { + if !l.cfg.req.Filters.StatusCodes.CheckKeep(req, res) || + !l.cfg.req.Filters.Method.CheckKeep(req, res) || + !l.cfg.req.Filters.Headers.CheckKeep(req, res) || + !l.cfg.req.Filters.CIDR.CheckKeep(req, res) { return false } return true @@ -135,19 +158,29 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) { line := l.lineBufPool.Get() defer l.lineBufPool.Put(line) - line = l.Formatter.AppendLog(line, req, res) + line = l.AppendRequestLog(line, req, res) if line[len(line)-1] != '\n' { line = append(line, '\n') } - l.lockWrite(line) + l.write(line) } func (l *AccessLogger) LogError(req *http.Request, err error) { l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) } +func (l *AccessLogger) LogACL(info *acl.IPInfo, blocked bool) { + line := l.lineBufPool.Get() + defer l.lineBufPool.Put(line) + line = l.ACLFormatter.AppendACLLog(line, info, blocked) + if line[len(line)-1] != '\n' { + line = append(line, '\n') + } + l.write(line) +} + func (l *AccessLogger) ShouldRotate() bool { - return l.cfg.Retention.IsValid() && l.supportRotate + return l.supportRotate != nil && l.cfg.Retention.IsValid() } func (l *AccessLogger) Rotate() (result *RotateResult, err error) { @@ -155,10 +188,21 @@ func (l *AccessLogger) Rotate() (result *RotateResult, err error) { return nil, nil } - l.io.Lock() - defer l.io.Unlock() + l.writer.Flush() + l.writeLock.Lock() + defer l.writeLock.Unlock() - return rotateLogFile(l.io.(supportRotate), l.cfg.Retention) + result = new(RotateResult) + for _, sr := range l.supportRotate { + r, err := rotateLogFile(sr, l.cfg.Retention) + if err != nil { + return nil, err + } + if r != nil { + result.Add(r) + } + } + return result, nil } func (l *AccessLogger) handleErr(err error) { @@ -172,11 +216,9 @@ func (l *AccessLogger) handleErr(err error) { func (l *AccessLogger) start() { defer func() { - defer l.task.Finish(nil) - defer l.close() - if err := l.Flush(); err != nil { - l.handleErr(err) - } + l.Flush() + l.Close() + l.task.Finish(nil) }() // flushes the buffer every 30 seconds @@ -191,9 +233,7 @@ func (l *AccessLogger) start() { case <-l.task.Context().Done(): return case <-flushTicker.C: - if err := l.Flush(); err != nil { - l.handleErr(err) - } + l.Flush() case <-rotateTicker.C: if !l.ShouldRotate() { continue @@ -210,27 +250,40 @@ func (l *AccessLogger) start() { } } -func (l *AccessLogger) Flush() error { - l.io.Lock() - defer l.io.Unlock() - return l.buffered.Flush() +func (l *AccessLogger) Close() error { + l.writeLock.Lock() + defer l.writeLock.Unlock() + if l.closed { + return nil + } + if l.closer != nil { + for _, c := range l.closer { + c.Close() + } + } + l.closed = true + return nil } -func (l *AccessLogger) close() { - if r, ok := l.io.(io.Closer); ok { - l.io.Lock() - defer l.io.Unlock() - r.Close() +func (l *AccessLogger) Flush() { + l.writeLock.Lock() + defer l.writeLock.Unlock() + if l.closed { + return + } + if err := l.writer.Flush(); err != nil { + l.handleErr(err) } } -func (l *AccessLogger) lockWrite(data []byte) { - l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers - _, err := l.buffered.Write(data) - l.io.Unlock() +func (l *AccessLogger) write(data []byte) { + l.writeLock.Lock() + defer l.writeLock.Unlock() + if l.closed { + return + } + _, err := l.writer.Write(data) if err != nil { l.handleErr(err) - } else { - logging.Trace().Msg("access log flushed to " + l.io.Name()) } } diff --git a/internal/logging/accesslog/access_logger_test.go b/internal/logging/accesslog/access_logger_test.go index 9d01ca25..c25c1205 100644 --- a/internal/logging/accesslog/access_logger_test.go +++ b/internal/logging/accesslog/access_logger_test.go @@ -52,18 +52,18 @@ var ( } ) -func fmtLog(cfg *Config) (ts string, line string) { +func fmtLog(cfg *RequestLoggerConfig) (ts string, line string) { buf := make([]byte, 0, 1024) t := time.Now() logger := NewMockAccessLogger(testTask, cfg) utils.MockTimeNow(t) - buf = logger.AppendLog(buf, req, resp) + buf = logger.AppendRequestLog(buf, req, resp) return t.Format(LogTimeFormat), string(buf) } func TestAccessLoggerCommon(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Format = FormatCommon ts, log := fmtLog(config) expect.Equal(t, log, @@ -74,7 +74,7 @@ func TestAccessLoggerCommon(t *testing.T) { } func TestAccessLoggerCombined(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Format = FormatCombined ts, log := fmtLog(config) expect.Equal(t, log, @@ -85,7 +85,7 @@ func TestAccessLoggerCombined(t *testing.T) { } func TestAccessLoggerRedactQuery(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Format = FormatCommon config.Fields.Query.Default = FieldModeRedact ts, log := fmtLog(config) @@ -115,7 +115,7 @@ type JSONLogEntry struct { Cookies map[string]string `json:"cookies,omitempty"` } -func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { +func getJSONEntry(t *testing.T, config *RequestLoggerConfig) JSONLogEntry { t.Helper() config.Format = FormatJSON var entry JSONLogEntry @@ -126,7 +126,7 @@ func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { } func TestAccessLoggerJSON(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() entry := getJSONEntry(t, config) expect.Equal(t, entry.IP, remote) expect.Equal(t, entry.Method, method) @@ -147,7 +147,7 @@ func TestAccessLoggerJSON(t *testing.T) { } func BenchmarkAccessLoggerJSON(b *testing.B) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Format = FormatJSON logger := NewMockAccessLogger(testTask, config) b.ResetTimer() @@ -157,7 +157,7 @@ func BenchmarkAccessLoggerJSON(b *testing.B) { } func BenchmarkAccessLoggerCombined(b *testing.B) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Format = FormatCombined logger := NewMockAccessLogger(testTask, config) b.ResetTimer() diff --git a/internal/logging/accesslog/back_scanner.go b/internal/logging/accesslog/back_scanner.go index bf17a1c7..2e93ac26 100644 --- a/internal/logging/accesslog/back_scanner.go +++ b/internal/logging/accesslog/back_scanner.go @@ -6,9 +6,14 @@ import ( "io" ) +type ReaderAtSeeker interface { + io.ReaderAt + io.Seeker +} + // BackScanner provides an interface to read a file backward line by line. type BackScanner struct { - file supportRotate + file ReaderAtSeeker size int64 chunkSize int chunkBuf []byte @@ -21,7 +26,7 @@ type BackScanner struct { // NewBackScanner creates a new Scanner to read the file backward. // chunkSize determines the size of each read chunk from the end of the file. -func NewBackScanner(file supportRotate, chunkSize int) *BackScanner { +func NewBackScanner(file ReaderAtSeeker, chunkSize int) *BackScanner { size, err := file.Seek(0, io.SeekEnd) if err != nil { return &BackScanner{err: err} @@ -29,7 +34,7 @@ func NewBackScanner(file supportRotate, chunkSize int) *BackScanner { return newBackScanner(file, size, make([]byte, chunkSize)) } -func newBackScanner(file supportRotate, fileSize int64, buf []byte) *BackScanner { +func newBackScanner(file ReaderAtSeeker, fileSize int64, buf []byte) *BackScanner { return &BackScanner{ file: file, size: fileSize, diff --git a/internal/logging/accesslog/back_scanner_test.go b/internal/logging/accesslog/back_scanner_test.go index 59e4ca67..02f00000 100644 --- a/internal/logging/accesslog/back_scanner_test.go +++ b/internal/logging/accesslog/back_scanner_test.go @@ -135,7 +135,7 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) { } func logEntry() []byte { - accesslog := NewMockAccessLogger(task.RootTask("test", false), &Config{ + accesslog := NewMockAccessLogger(task.RootTask("test", false), &RequestLoggerConfig{ Format: FormatJSON, }) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -148,7 +148,7 @@ func logEntry() []byte { res := httptest.NewRecorder() // server the request srv.Config.Handler.ServeHTTP(res, req) - b := accesslog.AppendLog(nil, req, res.Result()) + b := accesslog.AppendRequestLog(nil, req, res.Result()) if b[len(b)-1] != '\n' { b = append(b, '\n') } diff --git a/internal/logging/accesslog/config.go b/internal/logging/accesslog/config.go index 91455c0d..73952131 100644 --- a/internal/logging/accesslog/config.go +++ b/internal/logging/accesslog/config.go @@ -6,6 +6,32 @@ import ( ) type ( + ConfigBase struct { + BufferSize int `json:"buffer_size"` + Path string `json:"path"` + Stdout bool `json:"stdout"` + Retention *Retention `json:"retention" aliases:"keep"` + } + ACLLoggerConfig struct { + ConfigBase + LogAllowed bool `json:"log_allowed"` + } + RequestLoggerConfig struct { + ConfigBase + Format Format `json:"format" validate:"oneof=common combined json"` + Filters Filters `json:"filters"` + Fields Fields `json:"fields"` + } + Config struct { + *ConfigBase + acl *ACLLoggerConfig + req *RequestLoggerConfig + } + AnyConfig interface { + ToConfig() *Config + IO() (WriterWithName, error) + } + Format string Filters struct { StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"` @@ -19,15 +45,6 @@ type ( Query FieldConfig `json:"query"` Cookies FieldConfig `json:"cookies"` } - Config struct { - BufferSize int `json:"buffer_size"` - Format Format `json:"format" validate:"oneof=common combined json"` - Path string `json:"path"` - Stdout bool `json:"stdout"` - Filters Filters `json:"filters"` - Fields Fields `json:"fields"` - Retention *Retention `json:"retention"` - } ) var ( @@ -35,23 +52,57 @@ var ( FormatCombined Format = "combined" FormatJSON Format = "json" - AvailableFormats = []Format{FormatCommon, FormatCombined, FormatJSON} + ReqLoggerFormats = []Format{FormatCommon, FormatCombined, FormatJSON} ) const DefaultBufferSize = 64 * kilobyte // 64KB -func (cfg *Config) Validate() gperr.Error { +func (cfg *ConfigBase) Validate() gperr.Error { if cfg.Path == "" && !cfg.Stdout { return gperr.New("path or stdout is required") } return nil } -func DefaultConfig() *Config { +func (cfg *ConfigBase) IO() (WriterWithName, error) { + ios := make([]WriterWithName, 0, 2) + if cfg.Stdout { + ios = append(ios, stdoutIO) + } + if cfg.Path != "" { + io, err := newFileIO(cfg.Path) + if err != nil { + return nil, err + } + ios = append(ios, io) + } + if len(ios) == 0 { + return nil, nil + } + return NewMultiWriter(ios...), nil +} + +func (cfg *ACLLoggerConfig) ToConfig() *Config { return &Config{ - BufferSize: DefaultBufferSize, - Format: FormatCombined, - Retention: &Retention{Days: 30}, + ConfigBase: &cfg.ConfigBase, + acl: cfg, + } +} + +func (cfg *RequestLoggerConfig) ToConfig() *Config { + return &Config{ + ConfigBase: &cfg.ConfigBase, + req: cfg, + } +} + +func DefaultRequestLoggerConfig() *RequestLoggerConfig { + return &RequestLoggerConfig{ + ConfigBase: ConfigBase{ + BufferSize: DefaultBufferSize, + Retention: &Retention{Days: 30}, + }, + Format: FormatCombined, Fields: Fields{ Headers: FieldConfig{ Default: FieldModeDrop, @@ -66,6 +117,16 @@ func DefaultConfig() *Config { } } -func init() { - utils.RegisterDefaultValueFactory(DefaultConfig) +func DefaultACLLoggerConfig() *ACLLoggerConfig { + return &ACLLoggerConfig{ + ConfigBase: ConfigBase{ + BufferSize: DefaultBufferSize, + Retention: &Retention{Days: 30}, + }, + } +} + +func init() { + utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig) + utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig) } diff --git a/internal/logging/accesslog/config_test.go b/internal/logging/accesslog/config_test.go index 23e76c7e..55adde48 100644 --- a/internal/logging/accesslog/config_test.go +++ b/internal/logging/accesslog/config_test.go @@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) { parsed, err := docker.ParseLabels(labels) expect.NoError(t, err) - var config Config + var config RequestLoggerConfig err = utils.Deserialize(parsed, &config) expect.NoError(t, err) diff --git a/internal/logging/accesslog/fields_test.go b/internal/logging/accesslog/fields_test.go index 9ab62bd2..662e924b 100644 --- a/internal/logging/accesslog/fields_test.go +++ b/internal/logging/accesslog/fields_test.go @@ -10,7 +10,7 @@ import ( // Cookie header should be removed, // stored in JSONLogEntry.Cookies instead. func TestAccessLoggerJSONKeepHeaders(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Headers.Default = FieldModeKeep entry := getJSONEntry(t, config) for k, v := range req.Header { @@ -29,7 +29,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) { } func TestAccessLoggerJSONDropHeaders(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Headers.Default = FieldModeDrop entry := getJSONEntry(t, config) for k := range req.Header { @@ -46,7 +46,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) { } func TestAccessLoggerJSONRedactHeaders(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Headers.Default = FieldModeRedact entry := getJSONEntry(t, config) for k := range req.Header { @@ -57,7 +57,7 @@ func TestAccessLoggerJSONRedactHeaders(t *testing.T) { } func TestAccessLoggerJSONKeepCookies(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeKeep entry := getJSONEntry(t, config) @@ -67,7 +67,7 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) { } func TestAccessLoggerJSONRedactCookies(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeRedact entry := getJSONEntry(t, config) @@ -77,7 +77,7 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) { } func TestAccessLoggerJSONDropQuery(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Query.Default = FieldModeDrop entry := getJSONEntry(t, config) expect.Equal(t, entry.Query["foo"], nil) @@ -85,7 +85,7 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) { } func TestAccessLoggerJSONRedactQuery(t *testing.T) { - config := DefaultConfig() + config := DefaultRequestLoggerConfig() config.Fields.Query.Default = FieldModeRedact entry := getJSONEntry(t, config) expect.Equal(t, entry.Query["foo"], []string{RedactedValue}) diff --git a/internal/logging/accesslog/file_logger.go b/internal/logging/accesslog/file_logger.go index a3679ac5..b1691a33 100644 --- a/internal/logging/accesslog/file_logger.go +++ b/internal/logging/accesslog/file_logger.go @@ -12,7 +12,6 @@ import ( type File struct { *os.File - sync.Mutex // os.File.Name() may not equal to key of `openedFiles`. // Store it for later delete from `openedFiles`. @@ -26,18 +25,18 @@ var ( openedFilesMu sync.Mutex ) -func newFileIO(path string) (AccessLogIO, error) { +func newFileIO(path string) (WriterWithName, error) { openedFilesMu.Lock() + defer openedFilesMu.Unlock() var file *File path = pathPkg.Clean(path) if opened, ok := openedFiles[path]; ok { opened.refCount.Add() - file = opened + return opened, nil } else { f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) if err != nil { - openedFilesMu.Unlock() return nil, fmt.Errorf("access log open error: %w", err) } file = &File{File: f, path: path, refCount: utils.NewRefCounter()} @@ -45,7 +44,6 @@ func newFileIO(path string) (AccessLogIO, error) { go file.closeOnZero() } - openedFilesMu.Unlock() return file, nil } diff --git a/internal/logging/accesslog/file_logger_test.go b/internal/logging/accesslog/file_logger_test.go index b9961f63..884214b9 100644 --- a/internal/logging/accesslog/file_logger_test.go +++ b/internal/logging/accesslog/file_logger_test.go @@ -14,11 +14,11 @@ import ( func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { var wg sync.WaitGroup - cfg := DefaultConfig() + cfg := DefaultRequestLoggerConfig() cfg.Path = "test.log" loggerCount := 10 - accessLogIOs := make([]AccessLogIO, loggerCount) + accessLogIOs := make([]WriterWithName, loggerCount) // make test log file file, err := os.Create(cfg.Path) @@ -49,7 +49,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { file := NewMockFile() - cfg := DefaultConfig() + cfg := DefaultRequestLoggerConfig() cfg.BufferSize = 1024 parent := task.RootTask("test", false) diff --git a/internal/logging/accesslog/formatter.go b/internal/logging/accesslog/formatter.go index 4d95b91a..6b553ffa 100644 --- a/internal/logging/accesslog/formatter.go +++ b/internal/logging/accesslog/formatter.go @@ -8,6 +8,7 @@ import ( "strconv" "github.com/rs/zerolog" + acl "github.com/yusing/go-proxy/internal/acl/types" "github.com/yusing/go-proxy/internal/utils" ) @@ -17,6 +18,7 @@ type ( } CombinedFormatter struct{ CommonFormatter } JSONFormatter struct{ CommonFormatter } + ACLLogFormatter struct{} ) const LogTimeFormat = "02/Jan/2006:15:04:05 -0700" @@ -56,7 +58,7 @@ func clientIP(req *http.Request) string { return req.RemoteAddr } -func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { +func (f *CommonFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte { query := f.cfg.Query.IterQuery(req.URL.Query()) line = append(line, req.Host...) @@ -82,8 +84,8 @@ func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Re return line } -func (f *CombinedFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { - line = f.CommonFormatter.AppendLog(line, req, res) +func (f *CombinedFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte { + line = f.CommonFormatter.AppendRequestLog(line, req, res) line = append(line, " \""...) line = append(line, req.Referer()...) line = append(line, "\" \""...) @@ -118,14 +120,14 @@ func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.E } } -func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte { +func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte { query := f.cfg.Query.ZerologQuery(req.URL.Query()) headers := f.cfg.Headers.ZerologHeaders(req.Header) cookies := f.cfg.Cookies.ZerologCookies(req.Cookies()) contentType := res.Header.Get("Content-Type") writer := bytes.NewBuffer(line) - logger := zerolog.New(writer).With().Logger() + logger := zerolog.New(writer) event := logger.Info(). Str("time", utils.TimeNow().Format(LogTimeFormat)). Str("ip", clientIP(req)). @@ -155,3 +157,23 @@ func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Resp event.Send() return writer.Bytes() } + +func (f ACLLogFormatter) AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte { + writer := bytes.NewBuffer(line) + logger := zerolog.New(writer) + event := logger.Info(). + Str("time", utils.TimeNow().Format(LogTimeFormat)). + Str("ip", info.Str) + if blocked { + event.Str("action", "block") + } else { + event.Str("action", "allow") + } + if info.City != nil { + event.Str("iso_code", info.City.Country.IsoCode) + event.Str("time_zone", info.City.Location.TimeZone) + } + // NOTE: zerolog will append a newline to the buffer + event.Send() + return writer.Bytes() +} diff --git a/internal/logging/accesslog/multi_writer.go b/internal/logging/accesslog/multi_writer.go index 3577bc48..830e04ba 100644 --- a/internal/logging/accesslog/multi_writer.go +++ b/internal/logging/accesslog/multi_writer.go @@ -1,12 +1,19 @@ package accesslog -import "strings" +import ( + "io" + "strings" +) type MultiWriter struct { - writers []AccessLogIO + writers []WriterWithName } -func NewMultiWriter(writers ...AccessLogIO) AccessLogIO { +type MultiWriterInterface interface { + Unwrap() []io.Writer +} + +func NewMultiWriter(writers ...WriterWithName) WriterWithName { if len(writers) == 0 { return nil } @@ -18,6 +25,14 @@ func NewMultiWriter(writers ...AccessLogIO) AccessLogIO { } } +func (w *MultiWriter) Unwrap() []io.Writer { + writers := make([]io.Writer, len(w.writers)) + for i, writer := range w.writers { + writers[i] = writer + } + return writers +} + func (w *MultiWriter) Write(p []byte) (n int, err error) { for _, writer := range w.writers { writer.Write(p) @@ -25,18 +40,6 @@ func (w *MultiWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func (w *MultiWriter) Lock() { - for _, writer := range w.writers { - writer.Lock() - } -} - -func (w *MultiWriter) Unlock() { - for _, writer := range w.writers { - writer.Unlock() - } -} - func (w *MultiWriter) Name() string { names := make([]string, len(w.writers)) for i, writer := range w.writers { diff --git a/internal/logging/accesslog/rotate.go b/internal/logging/accesslog/rotate.go index 469038f4..f21348c6 100644 --- a/internal/logging/accesslog/rotate.go +++ b/internal/logging/accesslog/rotate.go @@ -12,9 +12,7 @@ import ( ) type supportRotate interface { - io.Reader - io.Writer - io.Seeker + io.ReadSeeker io.ReaderAt io.WriterAt Truncate(size int64) error @@ -41,6 +39,14 @@ func (r *RotateResult) Print(logger *zerolog.Logger) { Msg("log rotate result") } +func (r *RotateResult) Add(other *RotateResult) { + r.NumBytesRead += other.NumBytesRead + r.NumBytesKeep += other.NumBytesKeep + r.NumLinesRead += other.NumLinesRead + r.NumLinesKeep += other.NumLinesKeep + r.NumLinesInvalid += other.NumLinesInvalid +} + type lineInfo struct { Pos int64 // Position from the start of the file Size int64 // Size of this line diff --git a/internal/logging/accesslog/rotate_test.go b/internal/logging/accesslog/rotate_test.go index 8897d8be..6da053e5 100644 --- a/internal/logging/accesslog/rotate_test.go +++ b/internal/logging/accesslog/rotate_test.go @@ -53,11 +53,11 @@ func TestParseLogTime(t *testing.T) { } func TestRotateKeepLast(t *testing.T) { - for _, format := range AvailableFormats { + for _, format := range ReqLoggerFormats { t.Run(string(format)+" keep last", func(t *testing.T) { file := NewMockFile() utils.MockTimeNow(testTime) - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ Format: format, }) expect.Nil(t, logger.Config().Retention) @@ -65,7 +65,7 @@ func TestRotateKeepLast(t *testing.T) { for range 10 { logger.Log(req, resp) } - expect.NoError(t, logger.Flush()) + logger.Flush() expect.Greater(t, file.Len(), int64(0)) expect.Equal(t, file.NumLines(), 10) @@ -85,7 +85,7 @@ func TestRotateKeepLast(t *testing.T) { t.Run(string(format)+" keep days", func(t *testing.T) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ Format: format, }) expect.Nil(t, logger.Config().Retention) @@ -127,10 +127,10 @@ func TestRotateKeepLast(t *testing.T) { } func TestRotateKeepFileSize(t *testing.T) { - for _, format := range AvailableFormats { + for _, format := range ReqLoggerFormats { t.Run(string(format)+" keep size no rotation", func(t *testing.T) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ Format: format, }) expect.Nil(t, logger.Config().Retention) @@ -160,7 +160,7 @@ func TestRotateKeepFileSize(t *testing.T) { t.Run("keep size with rotation", func(t *testing.T) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ Format: FormatJSON, }) expect.Nil(t, logger.Config().Retention) @@ -189,10 +189,10 @@ func TestRotateKeepFileSize(t *testing.T) { // skipping invalid lines is not supported for keep file_size func TestRotateSkipInvalidTime(t *testing.T) { - for _, format := range AvailableFormats { + for _, format := range ReqLoggerFormats { t.Run(string(format), func(t *testing.T) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ Format: format, }) expect.Nil(t, logger.Config().Retention) @@ -232,9 +232,11 @@ func BenchmarkRotate(b *testing.B) { for _, retention := range tests { b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ - Format: FormatJSON, - Retention: retention, + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ + ConfigBase: ConfigBase{ + Retention: retention, + }, + Format: FormatJSON, }) for i := range 100 { utils.MockTimeNow(testTime.AddDate(0, 0, -100+i+1)) @@ -263,9 +265,11 @@ func BenchmarkRotateWithInvalidTime(b *testing.B) { for _, retention := range tests { b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) { file := NewMockFile() - logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{ - Format: FormatJSON, - Retention: retention, + logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{ + ConfigBase: ConfigBase{ + Retention: retention, + }, + Format: FormatJSON, }) for i := range 10000 { utils.MockTimeNow(testTime.AddDate(0, 0, -10000+i+1)) diff --git a/internal/logging/accesslog/stdout_logger.go b/internal/logging/accesslog/stdout_logger.go index 2e1f2456..30d7ca7a 100644 --- a/internal/logging/accesslog/stdout_logger.go +++ b/internal/logging/accesslog/stdout_logger.go @@ -11,8 +11,6 @@ type StdoutLogger struct { var stdoutIO = &StdoutLogger{os.Stdout} -func (l *StdoutLogger) Lock() {} -func (l *StdoutLogger) Unlock() {} func (l *StdoutLogger) Name() string { return "stdout" } diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index 92408af9..8a675d06 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/acl" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/logging" @@ -21,6 +22,7 @@ type Server struct { http *http.Server https *http.Server startTime time.Time + acl *acl.Config l zerolog.Logger } @@ -31,6 +33,7 @@ type Options struct { HTTPSAddr string CertProvider *autocert.Provider Handler http.Handler + ACL *acl.Config } type httpServer interface { @@ -76,6 +79,7 @@ func NewServer(opt Options) (s *Server) { http: httpSer, https: httpsSer, l: logger, + acl: opt.ACL, } } @@ -95,16 +99,16 @@ func (s *Server) Start(parent task.Parent) { Handler: s.https.Handler, TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig), } - Start(subtask, h3, &s.l) + Start(subtask, h3, s.acl, &s.l) s.http.Handler = advertiseHTTP3(s.http.Handler, h3) s.https.Handler = advertiseHTTP3(s.https.Handler, h3) } - Start(subtask, s.http, &s.l) - Start(subtask, s.https, &s.l) + Start(subtask, s.http, s.acl, &s.l) + Start(subtask, s.https, s.acl, &s.l) } -func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Logger) { +func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) { if srv == nil { return } @@ -130,6 +134,9 @@ func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Lo if srv.TLSConfig != nil { l = tls.NewListener(l, srv.TLSConfig) } + if acl != nil { + l = acl.WrapTCP(l) + } serveFunc = getServeFunc(l, srv.Serve) case *http3.Server: l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr) @@ -137,6 +144,9 @@ func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Lo HandleError(logger, err, "failed to listen on port") return } + if acl != nil { + l = acl.WrapUDP(l) + } serveFunc = getServeFunc(l, srv.Serve) } task.OnCancel("stop", func() { diff --git a/internal/route/route.go b/internal/route/route.go index fb866fa4..36e1af20 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -42,13 +42,13 @@ type ( Root string `json:"root,omitempty"` route.HTTPConfig - PathPatterns []string `json:"path_patterns,omitempty"` - Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name"` - HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` - LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` - Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` - Homepage *homepage.ItemConfig `json:"homepage,omitempty"` - AccessLog *accesslog.Config `json:"access_log,omitempty"` + PathPatterns []string `json:"path_patterns,omitempty"` + Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name"` + HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` + LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` + Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` + Homepage *homepage.ItemConfig `json:"homepage,omitempty"` + AccessLog *accesslog.RequestLoggerConfig `json:"access_log,omitempty"` Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`