mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 09:18:31 +02:00
breaking: move maxmind config to config.providers
- moved maxmind to separate module - code refactored - simplified test
This commit is contained in:
@@ -1,37 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||
)
|
||||
|
||||
var cityCache = xsync.NewMapOf[string, *acl.City]()
|
||||
|
||||
func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) {
|
||||
if ip.City != nil {
|
||||
return ip.City, true
|
||||
}
|
||||
|
||||
if cfg.db.Reader == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
city, ok := cityCache.Load(ip.Str)
|
||||
if ok {
|
||||
ip.City = city
|
||||
return city, true
|
||||
}
|
||||
|
||||
cfg.db.RLock()
|
||||
defer cfg.db.RUnlock()
|
||||
|
||||
city = new(acl.City)
|
||||
err := cfg.db.Lookup(ip.IP, city)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
cityCache.Store(ip.Str, city)
|
||||
ip.City = city
|
||||
return city, true
|
||||
}
|
||||
@@ -2,17 +2,13 @@ package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog"
|
||||
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/maxmind"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
@@ -20,43 +16,23 @@ import (
|
||||
type Config struct {
|
||||
Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow
|
||||
AllowLocal *bool `json:"allow_local"` // default: true
|
||||
Allow []string `json:"allow"`
|
||||
Deny []string `json:"deny"`
|
||||
Allow Matchers `json:"allow"`
|
||||
Deny Matchers `json:"deny"`
|
||||
Log *accesslog.ACLLoggerConfig `json:"log"`
|
||||
|
||||
MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"`
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
type (
|
||||
MaxMindDatabaseType string
|
||||
MaxMindConfig struct {
|
||||
AccountID string `json:"account_id" validate:"required"`
|
||||
LicenseKey string `json:"license_key" validate:"required"`
|
||||
Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"`
|
||||
|
||||
logger zerolog.Logger
|
||||
lastUpdate time.Time
|
||||
db struct {
|
||||
*maxminddb.Reader
|
||||
sync.RWMutex
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type config struct {
|
||||
defaultAllow bool
|
||||
allowLocal bool
|
||||
allow []matcher
|
||||
deny []matcher
|
||||
ipCache *xsync.MapOf[string, *checkCache]
|
||||
logAllowed bool
|
||||
logger *accesslog.AccessLogger
|
||||
}
|
||||
|
||||
type checkCache struct {
|
||||
*acl.IPInfo
|
||||
*maxmind.IPInfo
|
||||
allow bool
|
||||
created time.Time
|
||||
}
|
||||
@@ -74,11 +50,6 @@ const (
|
||||
ACLDeny = "deny"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxMindGeoLite MaxMindDatabaseType = "geolite"
|
||||
MaxMindGeoIP2 MaxMindDatabaseType = "geoip2"
|
||||
)
|
||||
|
||||
func (c *Config) Validate() gperr.Error {
|
||||
switch c.Default {
|
||||
case "", ACLAllow:
|
||||
@@ -95,55 +66,19 @@ func (c *Config) Validate() gperr.Error {
|
||||
c.allowLocal = true
|
||||
}
|
||||
|
||||
if c.MaxMind != nil {
|
||||
c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger()
|
||||
}
|
||||
|
||||
if c.Log != nil {
|
||||
c.logAllowed = c.Log.LogAllowed
|
||||
}
|
||||
|
||||
errs := gperr.NewBuilder("syntax error")
|
||||
c.allow = make([]matcher, 0, len(c.Allow))
|
||||
c.deny = make([]matcher, 0, len(c.Deny))
|
||||
|
||||
for _, s := range c.Allow {
|
||||
m, err := c.parseMatcher(s)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(s))
|
||||
continue
|
||||
}
|
||||
c.allow = append(c.allow, m)
|
||||
}
|
||||
for _, s := range c.Deny {
|
||||
m, err := c.parseMatcher(s)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(s))
|
||||
continue
|
||||
}
|
||||
c.deny = append(c.deny, m)
|
||||
}
|
||||
|
||||
if errs.HasError() {
|
||||
c.allow = nil
|
||||
c.deny = nil
|
||||
return errMatcherFormat.With(errs.Error())
|
||||
}
|
||||
|
||||
c.ipCache = xsync.NewMapOf[string, *checkCache]()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) Valid() bool {
|
||||
return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal)
|
||||
return c != nil && (len(c.Allow) > 0 || len(c.Deny) > 0 || c.allowLocal)
|
||||
}
|
||||
|
||||
func (c *Config) Start(parent *task.Task) gperr.Error {
|
||||
if c.MaxMind != nil {
|
||||
if err := c.MaxMind.LoadMaxMindDB(parent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.Log != nil {
|
||||
logger, err := accesslog.NewAccessLogger(parent, c.Log)
|
||||
if err != nil {
|
||||
@@ -154,9 +89,9 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) {
|
||||
func (c *Config) cacheRecord(info *maxmind.IPInfo, allow bool) {
|
||||
if common.ForceResolveCountry && info.City == nil {
|
||||
c.MaxMind.lookupCity(info)
|
||||
maxmind.LookupCity(info)
|
||||
}
|
||||
c.ipCache.Store(info.Str, &checkCache{
|
||||
IPInfo: info,
|
||||
@@ -165,7 +100,7 @@ func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) {
|
||||
})
|
||||
}
|
||||
|
||||
func (c *config) log(info *acl.IPInfo, allowed bool) {
|
||||
func (c *config) log(info *maxmind.IPInfo, allowed bool) {
|
||||
if c.logger == nil {
|
||||
return
|
||||
}
|
||||
@@ -186,7 +121,7 @@ func (c *Config) IPAllowed(ip net.IP) bool {
|
||||
}
|
||||
|
||||
if c.allowLocal && ip.IsPrivate() {
|
||||
c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true)
|
||||
c.log(&maxmind.IPInfo{IP: ip, Str: ip.String()}, true)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -197,15 +132,15 @@ func (c *Config) IPAllowed(ip net.IP) bool {
|
||||
return record.allow
|
||||
}
|
||||
|
||||
ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr}
|
||||
for _, m := range c.allow {
|
||||
ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr}
|
||||
for _, m := range c.Allow {
|
||||
if m(ipAndStr) {
|
||||
c.log(ipAndStr, true)
|
||||
c.cacheRecord(ipAndStr, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, m := range c.deny {
|
||||
for _, m := range c.Deny {
|
||||
if m(ipAndStr) {
|
||||
c.log(ipAndStr, false)
|
||||
c.cacheRecord(ipAndStr, false)
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/maxmind"
|
||||
)
|
||||
|
||||
type matcher func(*acl.IPInfo) bool
|
||||
type Matcher func(*maxmind.IPInfo) bool
|
||||
type Matchers []Matcher
|
||||
|
||||
const (
|
||||
MatcherTypeIP = "ip"
|
||||
@@ -32,7 +33,7 @@ var (
|
||||
errMaxMindNotConfigured = gperr.New("MaxMind not configured")
|
||||
)
|
||||
|
||||
func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
|
||||
func ParseMatcher(s string) (Matcher, gperr.Error) {
|
||||
parts := strings.Split(s, ":")
|
||||
if len(parts) != 2 {
|
||||
return nil, errSyntax
|
||||
@@ -52,35 +53,44 @@ func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
|
||||
}
|
||||
return matchCIDR(net), nil
|
||||
case MatcherTypeTimeZone:
|
||||
if cfg.MaxMind == nil {
|
||||
if !maxmind.HasInstance() {
|
||||
return nil, errMaxMindNotConfigured
|
||||
}
|
||||
return cfg.MaxMind.matchTimeZone(parts[1]), nil
|
||||
return matchTimeZone(parts[1]), nil
|
||||
case MatcherTypeCountry:
|
||||
if cfg.MaxMind == nil {
|
||||
if !maxmind.HasInstance() {
|
||||
return nil, errMaxMindNotConfigured
|
||||
}
|
||||
return cfg.MaxMind.matchISOCode(parts[1]), nil
|
||||
return matchISOCode(parts[1]), nil
|
||||
default:
|
||||
return nil, errSyntax
|
||||
}
|
||||
}
|
||||
|
||||
func matchIP(ip net.IP) matcher {
|
||||
return func(ip2 *acl.IPInfo) bool {
|
||||
func (matchers Matchers) Match(ip *maxmind.IPInfo) bool {
|
||||
for _, m := range matchers {
|
||||
if m(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIP(ip net.IP) Matcher {
|
||||
return func(ip2 *maxmind.IPInfo) bool {
|
||||
return ip.Equal(ip2.IP)
|
||||
}
|
||||
}
|
||||
|
||||
func matchCIDR(n *net.IPNet) matcher {
|
||||
return func(ip *acl.IPInfo) bool {
|
||||
func matchCIDR(n *net.IPNet) Matcher {
|
||||
return func(ip *maxmind.IPInfo) bool {
|
||||
return n.Contains(ip.IP)
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
|
||||
return func(ip *acl.IPInfo) bool {
|
||||
city, ok := cfg.lookupCity(ip)
|
||||
func matchTimeZone(tz string) Matcher {
|
||||
return func(ip *maxmind.IPInfo) bool {
|
||||
city, ok := maxmind.LookupCity(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -88,9 +98,9 @@ func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) matchISOCode(iso string) matcher {
|
||||
return func(ip *acl.IPInfo) bool {
|
||||
city, ok := cfg.lookupCity(ip)
|
||||
func matchISOCode(iso string) Matcher {
|
||||
return func(ip *maxmind.IPInfo) bool {
|
||||
city, ok := maxmind.LookupCity(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
var (
|
||||
updateInterval = 24 * time.Hour
|
||||
httpClient = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
ErrResponseNotOK = gperr.New("response not OK")
|
||||
ErrDownloadFailure = gperr.New("download failure")
|
||||
)
|
||||
|
||||
func dbPathImpl(dbType MaxMindDatabaseType) string {
|
||||
if dbType == MaxMindGeoLite {
|
||||
return filepath.Join(dataDir, "GeoLite2-City.mmdb")
|
||||
}
|
||||
return filepath.Join(dataDir, "GeoIP2-City.mmdb")
|
||||
}
|
||||
|
||||
func dbURLimpl(dbType MaxMindDatabaseType) string {
|
||||
if dbType == MaxMindGeoLite {
|
||||
return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"
|
||||
}
|
||||
return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"
|
||||
}
|
||||
|
||||
func dbFilename(dbType MaxMindDatabaseType) string {
|
||||
if dbType == MaxMindGeoLite {
|
||||
return "GeoLite2-City.mmdb"
|
||||
}
|
||||
return "GeoIP2-City.mmdb"
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error {
|
||||
if cfg.Database == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
path := dbPath(cfg.Database)
|
||||
reader, err := maxmindDBOpen(path)
|
||||
valid := true
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
default:
|
||||
// ignore invalid error, just download it again
|
||||
var invalidErr maxminddb.InvalidDatabaseError
|
||||
if !errors.As(err, &invalidErr) {
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
}
|
||||
valid = false
|
||||
}
|
||||
|
||||
if !valid {
|
||||
cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...")
|
||||
if err = cfg.download(); err != nil {
|
||||
return ErrDownloadFailure.With(err)
|
||||
}
|
||||
} else {
|
||||
cfg.logger.Info().Msg("MaxMind DB loaded")
|
||||
cfg.db.Reader = reader
|
||||
go cfg.scheduleUpdate(parent)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) loadLastUpdate() {
|
||||
f, err := os.Stat(dbPath(cfg.Database))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.lastUpdate = f.ModTime()
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) setLastUpdate(t time.Time) {
|
||||
cfg.lastUpdate = t
|
||||
_ = os.Chtimes(dbPath(cfg.Database), t, t)
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) {
|
||||
task := parent.Subtask("schedule_update", true)
|
||||
ticker := time.NewTicker(updateInterval)
|
||||
|
||||
cfg.loadLastUpdate()
|
||||
cfg.update()
|
||||
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
if cfg.db.Reader != nil {
|
||||
cfg.db.Reader.Close()
|
||||
}
|
||||
task.Finish(nil)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cfg.update()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) update() {
|
||||
// check for update
|
||||
cfg.logger.Info().Msg("checking for MaxMind DB update...")
|
||||
remoteLastModified, err := cfg.checkLastest()
|
||||
if err != nil {
|
||||
cfg.logger.Err(err).Msg("failed to check MaxMind DB update")
|
||||
return
|
||||
}
|
||||
if remoteLastModified.Equal(cfg.lastUpdate) {
|
||||
cfg.logger.Info().Msg("MaxMind DB is up to date")
|
||||
return
|
||||
}
|
||||
|
||||
cfg.logger.Info().
|
||||
Time("latest", remoteLastModified.Local()).
|
||||
Time("current", cfg.lastUpdate).
|
||||
Msg("MaxMind DB update available")
|
||||
if err = cfg.download(); err != nil {
|
||||
cfg.logger.Err(err).Msg("failed to update MaxMind DB")
|
||||
return
|
||||
}
|
||||
cfg.logger.Info().Msg("MaxMind DB updated")
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(method, dbURL(cfg.Database), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) {
|
||||
resp, err := newReq(cfg, http.MethodHead)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
lastModified := resp.Header.Get("Last-Modified")
|
||||
if lastModified == "" {
|
||||
cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
|
||||
if err != nil {
|
||||
cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &lastModifiedTime, nil
|
||||
}
|
||||
|
||||
func (cfg *MaxMindConfig) download() error {
|
||||
resp, err := newReq(cfg, http.MethodGet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
dbFile := dbPath(cfg.Database)
|
||||
tmpGZPath := dbFile + "-tmp.tar.gz"
|
||||
tmpDBPath := dbFile + "-tmp"
|
||||
|
||||
tmpGZFile, err := os.OpenFile(tmpGZPath, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanup the tar.gz file
|
||||
defer func() {
|
||||
_ = tmpGZFile.Close()
|
||||
_ = os.Remove(tmpGZPath)
|
||||
}()
|
||||
|
||||
cfg.logger.Info().Msg("MaxMind DB downloading...")
|
||||
|
||||
_, err = io.Copy(tmpGZFile, resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tmpGZFile.Seek(0, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extract .tar.gz and to database
|
||||
err = extractFileFromTarGz(tmpGZFile, dbFilename(cfg.Database), tmpDBPath)
|
||||
|
||||
if err != nil {
|
||||
return gperr.New("failed to extract database from archive").With(err)
|
||||
}
|
||||
|
||||
// test if the downloaded database is valid
|
||||
db, err := maxmindDBOpen(tmpDBPath)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpDBPath)
|
||||
return err
|
||||
}
|
||||
|
||||
db.Close()
|
||||
err = os.Rename(tmpDBPath, dbFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.db.Lock()
|
||||
defer cfg.db.Unlock()
|
||||
if cfg.db.Reader != nil {
|
||||
cfg.db.Reader.Close()
|
||||
}
|
||||
cfg.db.Reader, err = maxmindDBOpen(dbFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastModifiedStr := resp.Header.Get("Last-Modified")
|
||||
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModifiedStr)
|
||||
if err == nil {
|
||||
cfg.setLastUpdate(lastModifiedTime)
|
||||
}
|
||||
|
||||
cfg.logger.Info().Msg("MaxMind DB downloaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFileFromTarGz(tarGzFile *os.File, targetFilename, destPath string) error {
|
||||
defer tarGzFile.Close()
|
||||
|
||||
gzr, err := gzip.NewReader(tarGzFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
tr := tar.NewReader(gzr)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break // End of archive
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Only extract the file that matches targetFilename (basename match)
|
||||
if filepath.Base(hdr.Name) == targetFilename {
|
||||
outFile, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, hdr.FileInfo().Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer outFile.Close()
|
||||
_, err = io.Copy(outFile, tr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil // Done
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("file %s not found in archive", targetFilename)
|
||||
}
|
||||
|
||||
var (
|
||||
dataDir = common.DataDir
|
||||
dbURL = dbURLimpl
|
||||
dbPath = dbPathImpl
|
||||
maxmindDBOpen = maxminddb.Open
|
||||
newReq = (*MaxMindConfig).newReq
|
||||
)
|
||||
@@ -1,223 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
func Test_dbPath(t *testing.T) {
|
||||
tmpDataDir := "/tmp/testdata"
|
||||
oldDataDir := dataDir
|
||||
dataDir = tmpDataDir
|
||||
defer func() { dataDir = oldDataDir }()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType MaxMindDatabaseType
|
||||
want string
|
||||
}{
|
||||
{"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")},
|
||||
{"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := dbPath(tt.dbType); got != tt.want {
|
||||
t.Errorf("dbPath() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dbURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType MaxMindDatabaseType
|
||||
want string
|
||||
}{
|
||||
{"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"},
|
||||
{"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := dbURL(tt.dbType); got != tt.want {
|
||||
t.Errorf("dbURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper for MaxMindConfig ---
|
||||
type testLogger struct{ zerolog.Logger }
|
||||
|
||||
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
|
||||
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
|
||||
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
|
||||
|
||||
func Test_MaxMindConfig_newReq(t *testing.T) {
|
||||
cfg := &MaxMindConfig{
|
||||
AccountID: "testid",
|
||||
LicenseKey: "testkey",
|
||||
Database: MaxMindGeoLite,
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
|
||||
// Patch httpClient to use httptest
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
|
||||
t.Errorf("basic auth not set correctly")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
oldURL := dbURL
|
||||
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||
defer func() { dbURL = oldURL }()
|
||||
|
||||
resp, err := cfg.newReq(http.MethodGet)
|
||||
if err != nil {
|
||||
t.Fatalf("newReq() error = %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("unexpected status: %v", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MaxMindConfig_checkUpdate(t *testing.T) {
|
||||
cfg := &MaxMindConfig{
|
||||
AccountID: "id",
|
||||
LicenseKey: "key",
|
||||
Database: MaxMindGeoLite,
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
lastMod := time.Now().UTC().Format(http.TimeFormat)
|
||||
buildTime := time.Now().Add(-time.Hour)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Last-Modified", lastMod)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
oldURL := dbURL
|
||||
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||
defer func() { dbURL = oldURL }()
|
||||
|
||||
latest, err := cfg.checkLastest()
|
||||
if err != nil {
|
||||
t.Fatalf("checkUpdate() error = %v", err)
|
||||
}
|
||||
if latest.Equal(buildTime) {
|
||||
t.Errorf("expected update needed")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeReadCloser struct {
|
||||
firstRead bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *fakeReadCloser) Read(p []byte) (int, error) {
|
||||
if !c.firstRead {
|
||||
c.firstRead = true
|
||||
return strings.NewReader("FAKEMMDB").Read(p)
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *fakeReadCloser) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_MaxMindConfig_download(t *testing.T) {
|
||||
cfg := &MaxMindConfig{
|
||||
AccountID: "id",
|
||||
LicenseKey: "key",
|
||||
Database: MaxMindGeoLite,
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gz := gzip.NewWriter(w)
|
||||
t := tar.NewWriter(gz)
|
||||
t.WriteHeader(&tar.Header{
|
||||
Name: dbFilename(MaxMindGeoLite),
|
||||
})
|
||||
t.Write([]byte("1234"))
|
||||
t.Close()
|
||||
gz.Close()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oldURL := dbURL
|
||||
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||
defer func() { dbURL = oldURL }()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
oldDataDir := dataDir
|
||||
dataDir = tmpDir
|
||||
defer func() { dataDir = oldDataDir }()
|
||||
|
||||
// Patch maxminddb.Open to always succeed
|
||||
origOpen := maxmindDBOpen
|
||||
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
||||
return &maxminddb.Reader{}, nil
|
||||
}
|
||||
defer func() { maxmindDBOpen = origOpen }()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newReq() error = %v", err)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oldNewReq := newReq
|
||||
newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) {
|
||||
server.Config.Handler.ServeHTTP(rw, req)
|
||||
return rw.Result(), nil
|
||||
}
|
||||
defer func() { newReq = oldNewReq }()
|
||||
|
||||
err = cfg.download()
|
||||
if err != nil {
|
||||
t.Fatalf("download() error = %v", err)
|
||||
}
|
||||
if cfg.db.Reader == nil {
|
||||
t.Error("expected db instance")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
|
||||
// This test should cover both the path where DB exists and where it does not
|
||||
// For brevity, only the non-existing path is tested here
|
||||
cfg := &MaxMindConfig{
|
||||
AccountID: "id",
|
||||
LicenseKey: "key",
|
||||
Database: MaxMindGeoLite,
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
oldOpen := maxmindDBOpen
|
||||
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
||||
return &maxminddb.Reader{}, nil
|
||||
}
|
||||
defer func() { maxmindDBOpen = oldOpen }()
|
||||
|
||||
oldDBPath := dbPath
|
||||
dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") }
|
||||
defer func() { dbPath = oldDBPath }()
|
||||
|
||||
task := task.RootTask("test")
|
||||
defer task.Finish(nil)
|
||||
err := cfg.LoadMaxMindDB(task)
|
||||
if err != nil {
|
||||
t.Errorf("loadMaxMindDB() error = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package acl
|
||||
|
||||
type City struct {
|
||||
Location struct {
|
||||
TimeZone string `maxminddb:"time_zone"`
|
||||
} `maxminddb:"location"`
|
||||
Country struct {
|
||||
IsoCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package acl
|
||||
|
||||
import "net"
|
||||
|
||||
type IPInfo struct {
|
||||
IP net.IP
|
||||
Str string
|
||||
City *City
|
||||
}
|
||||
Reference in New Issue
Block a user