Compare commits

...

8 Commits

Author SHA1 Message Date
yusing
8b3e058885 fix: error formatting 2025-05-14 20:34:41 +08:00
yusing
023cbc81bc ci: update Docker CI workflows to exclude tags for socket-proxy and improve caching 2025-05-14 13:50:12 +08:00
yusing
b490e8c475 fix(acl): maxmind error even if configured, refactor 2025-05-14 13:44:43 +08:00
yusing
8e27886235 fix: incorrect unmarshal behavior for pointer primitives 2025-05-14 12:20:52 +08:00
yusing
7435b8e485 tests: add test for acl matchers 2025-05-13 20:11:16 +08:00
yusing
21724c037f fix: error formatting 2025-05-13 20:11:03 +08:00
yusing
44b4cff35e fix: acl matcher parsing, refactor 2025-05-13 19:40:43 +08:00
yusing
1e24765b17 fix: nil when printing error in edge cases 2025-05-13 19:40:04 +08:00
17 changed files with 256 additions and 71 deletions

View File

@@ -2,8 +2,16 @@ name: Docker Image CI (socket-proxy)
on:
push:
branches:
- main
paths:
- "socket-proxy/**"
tags-ignore:
- '**'
workflow_dispatch:
permissions:
contents: read
jobs:
build:

View File

@@ -84,10 +84,10 @@ jobs:
outputs: type=image,name=${{ env.REGISTRY }}/${{ inputs.image_name }},push-by-digest=true,name-canonical=true,push=true
cache-from: |
type=registry,ref=${{ env.REGISTRY }}/${{ inputs.image_name }}:buildcache-${{ env.PLATFORM_PAIR }}
type=gha,scope=${{ github.workflow }}
type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }}
cache-to: |
type=registry,ref=${{ env.REGISTRY }}/${{ inputs.image_name }}:buildcache-${{ env.PLATFORM_PAIR }},mode=max
type=gha,scope=${{ github.workflow }},mode=max
type=gha,scope=${{ github.workflow }}-${{ env.PLATFORM_PAIR }},mode=max
build-args: |
VERSION=${{ github.ref_name }}
MAKE_ARGS=${{ env.MAKE_ARGS }}

View File

@@ -7,6 +7,7 @@ import (
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/task"
@@ -21,6 +22,7 @@ type Config struct {
Log *accesslog.ACLLoggerConfig `json:"log"`
config
valErr gperr.Error
}
type config struct {
@@ -57,7 +59,8 @@ func (c *Config) Validate() gperr.Error {
case ACLDeny:
c.defaultAllow = false
default:
return gperr.New("invalid default value").Subject(c.Default)
c.valErr = gperr.New("invalid default value").Subject(c.Default)
return c.valErr
}
if c.AllowLocal != nil {
@@ -70,12 +73,17 @@ func (c *Config) Validate() gperr.Error {
c.logAllowed = c.Log.LogAllowed
}
if !c.allowLocal && !c.defaultAllow && len(c.Allow) == 0 {
c.valErr = gperr.New("allow_local is false and default is deny, but no allow rules are configured")
return c.valErr
}
c.ipCache = xsync.NewMapOf[string, *checkCache]()
return nil
}
func (c *Config) Valid() bool {
return c != nil && (len(c.Allow) > 0 || len(c.Deny) > 0 || c.allowLocal)
return c != nil && c.valErr == nil
}
func (c *Config) Start(parent *task.Task) gperr.Error {
@@ -86,6 +94,15 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
}
c.logger = logger
}
if c.valErr != nil {
return c.valErr
}
logging.Info().
Str("default", c.Default).
Bool("allow_local", c.allowLocal).
Int("allow_rules", len(c.Allow)).
Int("deny_rules", len(c.Deny)).
Msg("ACL started")
return nil
}
@@ -114,8 +131,7 @@ func (c *Config) IPAllowed(ip net.IP) bool {
return false
}
// always allow loopback
// loopback is not logged
// always allow loopback, not logged
if ip.IsLoopback() {
return true
}
@@ -133,19 +149,15 @@ func (c *Config) IPAllowed(ip net.IP) bool {
}
ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr}
for _, m := range c.Allow {
if m(ipAndStr) {
c.log(ipAndStr, true)
c.cacheRecord(ipAndStr, true)
return true
}
if c.Allow.Match(ipAndStr) {
c.log(ipAndStr, true)
c.cacheRecord(ipAndStr, true)
return true
}
for _, m := range c.Deny {
if m(ipAndStr) {
c.log(ipAndStr, false)
c.cacheRecord(ipAndStr, false)
return false
}
if c.Deny.Match(ipAndStr) {
c.log(ipAndStr, false)
c.cacheRecord(ipAndStr, false)
return false
}
c.log(ipAndStr, c.defaultAllow)

View File

@@ -8,7 +8,11 @@ import (
"github.com/yusing/go-proxy/internal/maxmind"
)
type Matcher func(*maxmind.IPInfo) bool
type MatcherFunc func(*maxmind.IPInfo) bool
type Matcher struct {
match MatcherFunc
}
type Matchers []Matcher
@@ -19,6 +23,9 @@ const (
MatcherTypeCountry = "country"
)
// TODO: use this error in the future
//
//nolint:unused
var errMatcherFormat = gperr.Multiline().AddLines(
"invalid matcher format, expect {type}:{value}",
"Available types: ip|cidr|tz|country",
@@ -29,68 +36,62 @@ var errMatcherFormat = gperr.Multiline().AddLines(
)
var (
errSyntax = gperr.New("syntax error")
errInvalidIP = gperr.New("invalid IP")
errInvalidCIDR = gperr.New("invalid CIDR")
errMaxMindNotConfigured = gperr.New("MaxMind not configured")
errSyntax = gperr.New("syntax error")
errInvalidIP = gperr.New("invalid IP")
errInvalidCIDR = gperr.New("invalid CIDR")
)
func ParseMatcher(s string) (Matcher, gperr.Error) {
func (matcher *Matcher) Parse(s string) error {
parts := strings.Split(s, ":")
if len(parts) != 2 {
return nil, errSyntax
return errSyntax
}
switch parts[0] {
case MatcherTypeIP:
ip := net.ParseIP(parts[1])
if ip == nil {
return nil, errInvalidIP
return errInvalidIP
}
return matchIP(ip), nil
matcher.match = matchIP(ip)
case MatcherTypeCIDR:
_, net, err := net.ParseCIDR(parts[1])
if err != nil {
return nil, errInvalidCIDR
return errInvalidCIDR
}
return matchCIDR(net), nil
matcher.match = matchCIDR(net)
case MatcherTypeTimeZone:
if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured
}
return matchTimeZone(parts[1]), nil
matcher.match = matchTimeZone(parts[1])
case MatcherTypeCountry:
if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured
}
return matchISOCode(parts[1]), nil
matcher.match = matchISOCode(parts[1])
default:
return nil, errSyntax
return errSyntax
}
return nil
}
func (matchers Matchers) Match(ip *maxmind.IPInfo) bool {
for _, m := range matchers {
if m(ip) {
if m.match(ip) {
return true
}
}
return false
}
func matchIP(ip net.IP) Matcher {
func matchIP(ip net.IP) MatcherFunc {
return func(ip2 *maxmind.IPInfo) bool {
return ip.Equal(ip2.IP)
}
}
func matchCIDR(n *net.IPNet) Matcher {
func matchCIDR(n *net.IPNet) MatcherFunc {
return func(ip *maxmind.IPInfo) bool {
return n.Contains(ip.IP)
}
}
func matchTimeZone(tz string) Matcher {
func matchTimeZone(tz string) MatcherFunc {
return func(ip *maxmind.IPInfo) bool {
city, ok := maxmind.LookupCity(ip)
if !ok {
@@ -100,7 +101,7 @@ func matchTimeZone(tz string) Matcher {
}
}
func matchISOCode(iso string) Matcher {
func matchISOCode(iso string) MatcherFunc {
return func(ip *maxmind.IPInfo) bool {
city, ok := maxmind.LookupCity(ip)
if !ok {

View File

@@ -0,0 +1,49 @@
package acl
import (
"net"
"reflect"
"testing"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/utils"
)
func TestMatchers(t *testing.T) {
strMatchers := []string{
"ip:127.0.0.1",
"cidr:10.0.0.0/8",
}
var mathers Matchers
err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
if err != nil {
t.Fatal(err)
}
tests := []struct {
ip string
want bool
}{
{"127.0.0.1", true},
{"10.0.0.1", true},
{"127.0.0.2", false},
{"192.168.0.1", false},
{"11.0.0.1", false},
}
for _, test := range tests {
ip := net.ParseIP(test.ip)
if ip == nil {
t.Fatalf("invalid ip: %s", test.ip)
}
got := mathers.Match(&maxmind.IPInfo{
IP: ip,
Str: test.ip,
})
if got != test.want {
t.Errorf("mathers.Match(%s) = %v, want %v", test.ip, got, test.want)
}
}
}

View File

@@ -10,12 +10,12 @@ type UDPListener struct {
lis net.PacketConn
}
func (cfg *Config) WrapUDP(lis net.PacketConn) net.PacketConn {
if cfg == nil {
func (c *Config) WrapUDP(lis net.PacketConn) net.PacketConn {
if c == nil {
return lis
}
return &UDPListener{
acl: cfg,
acl: c,
lis: lis,
}
}

View File

@@ -248,8 +248,6 @@ func (cfg *Config) load() gperr.Error {
err := model.ACL.Start(cfg.task)
if err != nil {
errs.Add(err)
} else {
logging.Info().Msg("ACL started")
}
}

View File

@@ -36,8 +36,11 @@ func (err *baseError) Subjectf(format string, args ...any) Error {
return err.Subject(format)
}
func (err baseError) With(extra error) Error {
return &nestedError{&err, []error{extra}}
func (err *baseError) With(extra error) Error {
if extra == nil {
return err
}
return &nestedError{&baseError{err.Err}, []error{extra}}
}
func (err baseError) Withf(format string, args ...any) Error {

View File

@@ -59,6 +59,9 @@ func (b *Builder) Error() Error {
if len(b.errs) == 0 {
return nil
}
if len(b.errs) == 1 && b.about == "" {
return wrap(b.errs[0])
}
return &nestedError{Err: New(b.about), Extras: b.errs}
}

View File

@@ -31,7 +31,7 @@ func (h *Hint) String() string {
return h.Error()
}
func DoYouMean(s string) *Hint {
func DoYouMean(s string) error {
if s == "" {
return nil
}

View File

@@ -11,9 +11,11 @@ type nestedError struct {
Extras []error `json:"extras"`
}
var emptyError = errStr("")
func (err nestedError) Subject(subject string) Error {
if err.Err == nil {
err.Err = PrependSubject(subject, errStr(""))
err.Err = PrependSubject(subject, emptyError)
} else {
err.Err = PrependSubject(subject, err.Err)
}
@@ -78,8 +80,10 @@ func (err *nestedError) fmtError(appendLine appendLineFunc) []byte {
}
if err.Err != nil {
buf := appendLine(nil, err.Err, 0)
buf = append(buf, '\n')
buf = appendLines(buf, err.Extras, 1, appendLine)
if len(err.Extras) > 0 {
buf = append(buf, '\n')
buf = appendLines(buf, err.Extras, 1, appendLine)
}
return buf
}
return appendLines(nil, err.Extras, 0, appendLine)

View File

@@ -45,6 +45,11 @@ func PrependSubject(subject string, err error) error {
switch err := err.(type) {
case *withSubject:
return err.Prepend(subject)
case *wrappedError:
return &wrappedError{
Err: PrependSubject(subject, err.Err),
Message: err.Message,
}
case Error:
return err.Subject(subject)
}
@@ -95,20 +100,24 @@ func (err *withSubject) Markdown() []byte {
func (err *withSubject) fmtError(highlight highlightFunc) []byte {
// subject is in reversed order
n := len(err.Subjects)
size := 0
errStr := err.Err.Error()
subjects := err.Subjects
if err.pendingSubject != "" {
subjects = append(subjects, err.pendingSubject)
}
var buf bytes.Buffer
for _, s := range err.Subjects {
for _, s := range subjects {
size += len(s)
}
n := len(subjects)
buf.Grow(size + 2 + n*len(subjectSep) + len(errStr) + len(highlight("")))
for i := n - 1; i > 0; i-- {
buf.WriteString(err.Subjects[i])
buf.WriteString(subjects[i])
buf.WriteString(subjectSep)
}
buf.WriteString(highlight(err.Subjects[0]))
buf.WriteString(highlight(subjects[0]))
if errStr != "" {
buf.WriteString(": ")
buf.WriteString(errStr)
@@ -127,6 +136,9 @@ func (err *withSubject) MarshalJSON() ([]byte, error) {
Subjects: subjects,
Err: err.Err,
}
if err.pendingSubject != "" {
reversed.Subjects = append(reversed.Subjects, err.pendingSubject)
}
return json.Marshal(reversed)
}

View File

@@ -87,6 +87,9 @@ func Join(errors ...error) Error {
func JoinLines(main error, errors ...string) Error {
errs := make([]error, len(errors))
for i, err := range errors {
if err == "" {
continue
}
errs[i] = newError(err)
}
return &nestedError{Err: main, Extras: errs}

View File

@@ -1,12 +1,29 @@
package maxmind
import (
"sync"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/task"
)
var instance *MaxMind
var warnOnce sync.Once
func warnNotConfigured() {
log.Warn().Msg("MaxMind not configured, geo lookup will fail")
notif.Notify(&notif.LogMessage{
Level: zerolog.WarnLevel,
Title: "MaxMind not configured",
Body: notif.MessageBody("MaxMind is not configured, geo lookup will fail"),
Color: notif.ColorError,
})
}
func SetInstance(parent task.Parent, cfg *Config) gperr.Error {
newInstance := &MaxMind{Config: cfg}
if err := newInstance.LoadMaxMindDB(parent); err != nil {
@@ -22,6 +39,7 @@ func HasInstance() bool {
func LookupCity(ip *IPInfo) (*City, bool) {
if instance == nil {
warnOnce.Do(warnNotConfigured)
return nil, false
}
return instance.lookupCity(ip)

View File

@@ -256,7 +256,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
if field, ok := mapping[strutils.ToLowerNoSnake(k)]; ok {
err := Convert(reflect.ValueOf(v), field, !hasValidateTag)
if err != nil {
errs.Add(err)
errs.Add(err.Subject(k))
}
} else {
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
@@ -314,12 +314,26 @@ func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.
return gperr.Errorf("convert: dst is %w", ErrNilValue)
}
if !src.IsValid() || src.IsZero() {
if dst.CanSet() {
dst.Set(reflect.Zero(dst.Type()))
return nil
if (src.Kind() == reflect.Pointer && src.IsNil()) || !src.IsValid() {
if !dst.CanSet() {
return gperr.Errorf("convert: src is %w", ErrNilValue)
}
return gperr.Errorf("convert: src is %w", ErrNilValue)
// manually set nil
dst.Set(reflect.Zero(dst.Type()))
return nil
}
if src.IsZero() {
if !dst.CanSet() {
return gperr.Errorf("convert: src is %w", ErrNilValue)
}
switch dst.Kind() {
case reflect.Pointer, reflect.Interface:
dst.Set(reflect.New(dst.Type().Elem()))
default:
dst.Set(reflect.Zero(dst.Type()))
}
return nil
}
srcT := src.Type()
@@ -330,10 +344,6 @@ func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.
srcT = src.Type()
}
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(New(dstT.Elem()))
@@ -346,16 +356,25 @@ func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.
switch {
case srcT.AssignableTo(dstT):
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
dst.Set(src)
return nil
// case srcT.ConvertibleTo(dstT):
// dst.Set(src.Convert(dstT))
// return nil
case srcKind == reflect.String:
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
if convertible, err := ConvertString(src.String(), dst); convertible {
return err
}
case isIntFloat(srcKind):
if !dst.CanSet() {
return ErrUnsettable.Subject(dstT.String())
}
var strV string
switch {
case src.CanInt():
@@ -386,7 +405,7 @@ func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr.
if dstT.Kind() != reflect.Slice {
return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
}
sliceErrs := gperr.NewBuilder("slice conversion errors")
sliceErrs := gperr.NewBuilder()
newSlice := reflect.MakeSlice(dstT, src.Len(), src.Len())
i := 0
for j, v := range src.Seq2() {
@@ -469,7 +488,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
if !isMultiline && src[0] != '-' {
values := strutils.CommaSeperatedList(src)
dst.Set(reflect.MakeSlice(dst.Type(), len(values), len(values)))
errs := gperr.NewBuilder("invalid slice values")
errs := gperr.NewBuilder()
for i, v := range values {
err := Convert(reflect.ValueOf(v), dst.Index(i), true)
if err != nil {

View File

@@ -73,6 +73,61 @@ func TestDeserializeAnonymousField(t *testing.T) {
ExpectEqual(t, s2.C, 3)
}
func TestPointerPrimitives(t *testing.T) {
type testType struct {
B *bool `json:"b"`
I8 *int8 `json:"i8"`
I16 *int16 `json:"i16"`
I32 *int32 `json:"i32"`
I64 *int64 `json:"i64"`
U8 *uint8 `json:"u8"`
U16 *uint16 `json:"u16"`
U32 *uint32 `json:"u32"`
U64 *uint64 `json:"u64"`
}
var test testType
err := MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
ExpectNoError(t, err)
ExpectEqual(t, *test.B, true)
ExpectEqual(t, *test.I8, int8(127))
ExpectEqual(t, *test.I16, int16(127))
ExpectEqual(t, *test.I32, int32(127))
ExpectEqual(t, *test.I64, int64(127))
ExpectEqual(t, *test.U8, uint8(127))
ExpectEqual(t, *test.U16, uint16(127))
ExpectEqual(t, *test.U32, uint32(127))
ExpectEqual(t, *test.U64, uint64(127))
// zero values
err = MapUnmarshalValidate(map[string]any{"b": false, "i8": int8(0), "i16": int16(0), "i32": int32(0), "i64": int64(0), "u8": uint8(0), "u16": uint16(0), "u32": uint32(0), "u64": uint64(0)}, &test)
ExpectNoError(t, err)
ExpectEqual(t, *test.B, false)
ExpectEqual(t, *test.I8, int8(0))
ExpectEqual(t, *test.I16, int16(0))
ExpectEqual(t, *test.I32, int32(0))
ExpectEqual(t, *test.I64, int64(0))
ExpectEqual(t, *test.U8, uint8(0))
ExpectEqual(t, *test.U16, uint16(0))
ExpectEqual(t, *test.U32, uint32(0))
ExpectEqual(t, *test.U64, uint64(0))
// nil values
err = MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test)
ExpectNoError(t, err)
err = MapUnmarshalValidate(map[string]any{"b": nil, "i8": nil, "i16": nil, "i32": nil, "i64": nil, "u8": nil, "u16": nil, "u32": nil, "u64": nil}, &test)
ExpectNoError(t, err)
ExpectEqual(t, test.B, nil)
ExpectEqual(t, test.I8, nil)
ExpectEqual(t, test.I16, nil)
ExpectEqual(t, test.I32, nil)
ExpectEqual(t, test.I64, nil)
ExpectEqual(t, test.U8, nil)
ExpectEqual(t, test.U16, nil)
ExpectEqual(t, test.U32, nil)
ExpectEqual(t, test.U64, nil)
}
func TestStringIntConvert(t *testing.T) {
s := "127"

View File

@@ -35,7 +35,7 @@ type (
}
)
var ErrNegativeInterval = errors.New("negative interval")
var ErrNegativeInterval = gperr.New("negative interval")
func NewMonitor(r routes.Route) health.HealthMonCheck {
var mon health.HealthMonCheck