fix modifyResponse middleware incorrect variable substitution

This commit is contained in:
yusing
2024-12-05 10:31:48 +08:00
parent a9f6c4eb20
commit aff8a3b401
10 changed files with 255 additions and 104 deletions

View File

@@ -22,32 +22,61 @@ var (
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
const (
VarRequestMethod = "$req_method"
VarRequestScheme = "$req_scheme"
VarRequestHost = "$req_host"
VarRequestPort = "$req_port"
VarRequestPath = "$req_path"
VarRequestAddr = "$req_addr"
VarRequestQuery = "$req_query"
VarRequestURL = "$req_url"
VarRequestURI = "$req_uri"
VarRequestContentType = "$req_content_type"
VarRequestContentLen = "$req_content_length"
VarRemoteAddr = "$remote_addr"
VarUpstreamScheme = "$upstream_scheme"
VarUpstreamHost = "$upstream_host"
VarUpstreamPort = "$upstream_port"
VarUpstreamAddr = "$upstream_addr"
VarUpstreamURL = "$upstream_url"
VarRespContentType = "$resp_content_type"
VarRespContentLen = "$resp_content_length"
VarRespStatusCode = "$status_code"
)
var staticReqVarSubsMap = map[string]reqVarGetter{
"$req_method": func(req *Request) string { return req.Method },
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
"$req_host": func(req *Request) string {
VarRequestMethod: func(req *Request) string { return req.Method },
VarRequestScheme: func(req *Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
"$req_port": func(req *Request) string {
VarRequestPort: func(req *Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
"$req_addr": func(req *Request) string { return req.Host },
"$req_path": func(req *Request) string { return req.URL.Path },
"$req_query": func(req *Request) string { return req.URL.RawQuery },
"$req_url": func(req *Request) string { return req.URL.String() },
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
"$upstream_addr": func(req *Request) string {
VarRequestAddr: func(req *Request) string { return req.Host },
VarRequestPath: func(req *Request) string { return req.URL.Path },
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *Request) string { return req.URL.String() },
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" {
@@ -55,7 +84,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
return upHost
},
"$upstream_url": func(req *Request) string {
VarUpstreamURL: func(req *Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return ""
@@ -71,9 +100,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
var staticRespVarSubsMap = map[string]respVarGetter{
"$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
"$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *Request, resp *Response, s string) string {
@@ -99,7 +128,7 @@ func varReplace(req *Request, resp *Response, s string) string {
if resp != nil {
// Replace response headers
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
return resp.Header.Get(header)
})
}