mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-20 15:31:24 +02:00
refactor header utils to httpheader package, cleanup api endpoints
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -12,39 +13,73 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||
"github.com/yusing/go-proxy/internal/metrics/uptime"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
type (
|
||||
ServeMux struct {
|
||||
*http.ServeMux
|
||||
cfg config.ConfigInstance
|
||||
}
|
||||
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
|
||||
)
|
||||
|
||||
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
|
||||
var handler http.HandlerFunc
|
||||
switch h := h.(type) {
|
||||
case func(http.ResponseWriter, *http.Request):
|
||||
handler = h
|
||||
case http.Handler:
|
||||
handler = h.ServeHTTP
|
||||
case WithCfgHandler:
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
h(mux.cfg, w, r)
|
||||
}
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported handler type: %T", h))
|
||||
}
|
||||
|
||||
matchDomains := mux.cfg.Value().MatchDomains
|
||||
if len(matchDomains) > 0 {
|
||||
origHandler := handler
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
if httpheaders.IsWebsocket(r.Header) {
|
||||
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
|
||||
}
|
||||
origHandler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(requireAuth) > 0 && requireAuth[0] {
|
||||
handler = auth.RequireAuth(handler)
|
||||
}
|
||||
|
||||
func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
|
||||
for _, m := range strutils.CommaSeperatedList(methods) {
|
||||
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||
mux := ServeMux{http.NewServeMux()}
|
||||
mux := ServeMux{http.NewServeMux(), cfg}
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
||||
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
|
||||
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("POST", "/v1/file/validate/{type}", auth.RequireAuth(v1.ValidateFile))
|
||||
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
|
||||
mux.HandleFunc("GET", "/v1/health/ws", auth.RequireAuth(useCfg(cfg, v1.HealthWS)))
|
||||
mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(memlogger.LogsWS(cfg.Value().MatchDomains)))
|
||||
mux.HandleFunc("GET", "/v1/favicon", auth.RequireAuth(favicon.GetFavIcon))
|
||||
mux.HandleFunc("POST", "/v1/homepage/set", auth.RequireAuth(v1.SetHomePageOverrides))
|
||||
mux.HandleFunc("GET", "/v1/agents/ws", auth.RequireAuth(useCfg(cfg, v1.AgentsWS)))
|
||||
mux.HandleFunc("GET", "/v1/metrics/system_info", auth.RequireAuth(useCfg(cfg, v1.SystemInfo)))
|
||||
mux.HandleFunc("GET", "/v1/metrics/system_info/ws", auth.RequireAuth(useCfg(cfg, v1.SystemInfo)))
|
||||
mux.HandleFunc("GET", "/v1/metrics/uptime", auth.RequireAuth(uptime.Poller.ServeHTTP))
|
||||
mux.HandleFunc("GET", "/v1/metrics/uptime/ws", auth.RequireAuth(useWS(cfg, uptime.Poller.ServeWS)))
|
||||
|
||||
mux.HandleFunc("GET", "/v1/stats", v1.Stats, true)
|
||||
mux.HandleFunc("POST", "/v1/reload", v1.Reload, true)
|
||||
mux.HandleFunc("GET", "/v1/list", v1.List, true)
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", v1.List, true)
|
||||
mux.HandleFunc("GET", "/v1/list/{what}/{which}", v1.List, true)
|
||||
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", v1.GetFileContent, true)
|
||||
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", v1.SetFileContent, true)
|
||||
mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true)
|
||||
mux.HandleFunc("GET", "/v1/health", v1.Health, true)
|
||||
mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true)
|
||||
mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true)
|
||||
mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true)
|
||||
mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true)
|
||||
mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
|
||||
mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)
|
||||
|
||||
if common.PrometheusEnabled {
|
||||
mux.Handle("GET /v1/metrics", promhttp.Handler())
|
||||
@@ -69,15 +104,3 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||
}
|
||||
return mux
|
||||
}
|
||||
|
||||
func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(cfg, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func useWS(cfg config.ConfigInstance, handler func(allowedDomains []string, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(cfg.Value().MatchDomains, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,16 @@ import (
|
||||
"github.com/coder/websocket/wsjson"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
)
|
||||
|
||||
func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 10*time.Second, func(conn *websocket.Conn) error {
|
||||
wsjson.Write(r.Context(), conn, cfg.ListAgents())
|
||||
return nil
|
||||
})
|
||||
if httpheaders.IsWebsocket(r.Header) {
|
||||
U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error {
|
||||
wsjson.Write(r.Context(), conn, cfg.ListAgents())
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
U.RespondJSON(w, r, cfg.ListAgents())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,16 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/route/routes/routequery"
|
||||
)
|
||||
|
||||
func HealthWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error {
|
||||
return wsjson.Write(r.Context(), conn, routequery.HealthMap())
|
||||
})
|
||||
func Health(w http.ResponseWriter, r *http.Request) {
|
||||
if httpheaders.IsWebsocket(r.Header) {
|
||||
U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
|
||||
return wsjson.Write(r.Context(), conn, routequery.HealthMap())
|
||||
})
|
||||
} else {
|
||||
U.RespondJSON(w, r, routequery.HealthMap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,17 +8,18 @@ import (
|
||||
"github.com/coder/websocket/wsjson"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, getStats(cfg))
|
||||
}
|
||||
|
||||
func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error {
|
||||
return wsjson.Write(r.Context(), conn, getStats(cfg))
|
||||
})
|
||||
if httpheaders.IsWebsocket(r.Header) {
|
||||
U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
|
||||
return wsjson.Write(r.Context(), conn, getStats(cfg))
|
||||
})
|
||||
} else {
|
||||
U.RespondJSON(w, r, getStats(cfg))
|
||||
}
|
||||
}
|
||||
|
||||
var startTime = time.Now()
|
||||
|
||||
@@ -2,24 +2,22 @@ package v1
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coder/websocket/wsjson"
|
||||
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
)
|
||||
|
||||
func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||
isWS := strings.HasSuffix(r.URL.Path, "/ws")
|
||||
agentName := r.URL.Query().Get("agent_name")
|
||||
query := r.URL.Query()
|
||||
agentName := query.Get("agent_name")
|
||||
query.Del("agent_name")
|
||||
if agentName == "" {
|
||||
if isWS {
|
||||
systeminfo.Poller.ServeWS(cfg.Value().MatchDomains, w, r)
|
||||
} else {
|
||||
systeminfo.Poller.ServeHTTP(w, r)
|
||||
}
|
||||
systeminfo.Poller.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -28,10 +26,12 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
|
||||
U.HandleErr(w, r, U.ErrInvalidKey("agent_name"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
isWS := httpheaders.IsWebsocket(r.Header)
|
||||
if !isWS {
|
||||
respData, status, err := agent.Forward(r, agentPkg.EndpointSystemInfo)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
U.HandleErr(w, r, E.Wrap(err, "failed to forward request to agent"))
|
||||
return
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
@@ -40,14 +40,16 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
U.WriteBody(w, respData)
|
||||
} else {
|
||||
clientConn, err := U.InitiateWS(cfg.Value().MatchDomains, w, r)
|
||||
r = r.WithContext(r.Context())
|
||||
clientConn, err := U.InitiateWS(w, r)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
U.HandleErr(w, r, E.Wrap(err, "failed to initiate websocket"))
|
||||
return
|
||||
}
|
||||
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo)
|
||||
defer clientConn.CloseNow()
|
||||
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo+"?"+query.Encode())
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
U.HandleErr(w, r, E.Wrap(err, "failed to connect to agent with websocket"))
|
||||
return
|
||||
}
|
||||
//nolint:errcheck
|
||||
@@ -63,7 +65,7 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
|
||||
err = wsjson.Write(r.Context(), clientConn, data)
|
||||
}
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
U.HandleErr(w, r, E.Wrap(err, "failed to write data to client"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
|
||||
)
|
||||
|
||||
func warnNoMatchDomains() {
|
||||
@@ -16,11 +17,12 @@ func warnNoMatchDomains() {
|
||||
|
||||
var warnNoMatchDomainOnce sync.Once
|
||||
|
||||
func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
func InitiateWS(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||
var originPats []string
|
||||
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
|
||||
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
|
||||
if len(allowedDomains) == 0 || common.IsDebug {
|
||||
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
||||
originPats = []string{"*"}
|
||||
@@ -40,8 +42,8 @@ func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request)
|
||||
})
|
||||
}
|
||||
|
||||
func PeriodicWS(allowedDomains []string, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||
conn, err := InitiateWS(allowedDomains, w, r)
|
||||
func PeriodicWS(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||
conn, err := InitiateWS(w, r)
|
||||
if err != nil {
|
||||
HandleErr(w, r, err)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user