diff --git a/internal/config/config.go b/internal/config/config.go index 5a6e1d73..3545594d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,6 @@ import ( "context" "errors" "os" - "regexp" "strconv" "strings" "sync" @@ -216,23 +215,10 @@ func (cfg *Config) StartServers(opts ...*StartServersOptions) { } } -var envRegex = regexp.MustCompile(`\$\{([^}]+)\}`) // e.g. ${CLOUDFLARE_API_KEY} -var readFile = os.ReadFile - -func (cfg *Config) readConfigFile() ([]byte, error) { - data, err := readFile(common.ConfigPath) - if err != nil { - return nil, err - } - return envRegex.ReplaceAllFunc(data, func(match []byte) []byte { - return strconv.AppendQuote(nil, os.Getenv(string(match[2:len(match)-1]))) - }), nil -} - func (cfg *Config) load() gperr.Error { const errMsg = "config load error" - data, err := cfg.readConfigFile() + data, err := os.ReadFile(common.ConfigPath) if err != nil { if os.IsNotExist(err) { log.Warn().Msg("config file not found, using default config") diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index 53bb0dc2..00000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package config - -import ( - "os" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestConfigEnvSubstitution(t *testing.T) { - os.Setenv("CLOUDFLARE_AUTH_TOKEN", "test") - readFile = func(_ string) ([]byte, error) { - return []byte(` ---- -autocert: - email: "test@test.com" - domains: - - "*.test.com" - provider: cloudflare - options: - auth_token: ${CLOUDFLARE_AUTH_TOKEN} -`), nil - } - - var cfg Config - out, err := cfg.readConfigFile() - require.NoError(t, err) - require.Equal(t, ` ---- -autocert: - email: "test@test.com" - domains: - - "*.test.com" - provider: cloudflare - options: - auth_token: "test" -`, string(out)) -} diff --git a/internal/serialization/serialization.go b/internal/serialization/serialization.go index 25971984..deb73169 100644 --- a/internal/serialization/serialization.go +++ b/internal/serialization/serialization.go @@ -5,6 +5,7 @@ import ( "errors" "os" "reflect" + "regexp" "strconv" "strings" "time" @@ -517,7 +518,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe return true, Convert(reflect.ValueOf(tmp), dst, true) } +var envRegex = regexp.MustCompile(`\$\{([^}]+)\}`) // e.g. ${CLOUDFLARE_API_KEY} + func UnmarshalValidateYAML[T any](data []byte, target *T) gperr.Error { + envError := gperr.NewBuilder("env substitution error") + data = envRegex.ReplaceAllFunc(data, func(match []byte) []byte { + varName := string(match[2 : len(match)-1]) + env, ok := os.LookupEnv(varName) + if !ok { + envError.Addf("%s is not set", varName) + } + return strconv.AppendQuote(nil, env) + }) + if envError.HasError() { + return envError.Error() + } + m := make(map[string]any) if err := yaml.Unmarshal(data, &m); err != nil { return gperr.Wrap(err) diff --git a/internal/serialization/serialization_test.go b/internal/serialization/serialization_test.go index d067f81d..ff656964 100644 --- a/internal/serialization/serialization_test.go +++ b/internal/serialization/serialization_test.go @@ -1,11 +1,13 @@ package serialization import ( + "os" "reflect" "strconv" "testing" "github.com/goccy/go-yaml" + "github.com/stretchr/testify/require" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -314,6 +316,26 @@ func TestStringToStruct(t *testing.T) { }) } +func TestConfigEnvSubstitution(t *testing.T) { + os.Setenv("CLOUDFLARE_AUTH_TOKEN", "test") + data := []byte(` +--- +autocert: + options: + auth_token: ${CLOUDFLARE_AUTH_TOKEN} +`) + + var cfg struct { + Autocert struct { + Options struct { + AuthToken string `yaml:"auth_token"` + } `yaml:"options"` + } `yaml:"autocert"` + } + require.NoError(t, UnmarshalValidateYAML(data, &cfg)) + require.Equal(t, "test", cfg.Autocert.Options.AuthToken) +} + func BenchmarkStringToStruct(b *testing.B) { for range b.N { dst := struct {