diff --git a/goutils b/goutils index 900faa77..813b4fae 160000 --- a/goutils +++ b/goutils @@ -1 +1 @@ -Subproject commit 900faa77c8e276acefaa73a327e09ee811a5af4c +Subproject commit 813b4fae7feeb0591544eb68704d1204c1c22192 diff --git a/internal/net/gphttp/middleware/crowdsec.go b/internal/net/gphttp/middleware/crowdsec.go new file mode 100644 index 00000000..ec3af22b --- /dev/null +++ b/internal/net/gphttp/middleware/crowdsec.go @@ -0,0 +1,203 @@ +package middleware + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/yusing/godoxy/internal/route/routes" + httputils "github.com/yusing/goutils/http" + ioutils "github.com/yusing/goutils/io" +) + +type ( + crowdsecMiddleware struct { + CrowdsecMiddlewareOpts + } + + CrowdsecMiddlewareOpts struct { + Route string `json:"route" validate:"required"` // route name (alias) or IP address + Port int `json:"port"` // port number (optional if using route name) + APIKey string `json:"api_key" validate:"required"` // API key for CrowdSec AppSec (mandatory) + Endpoint string `json:"endpoint"` // default: "/" + LogBlocked bool `json:"log_blocked"` // default: false + Timeout time.Duration `json:"timeout"` // default: 5 seconds + + httpClient *http.Client + } +) + +var Crowdsec = NewMiddleware[crowdsecMiddleware]() + +func (m *crowdsecMiddleware) setup() { + m.CrowdsecMiddlewareOpts = CrowdsecMiddlewareOpts{ + Route: "", + Port: 7422, // default port for CrowdSec AppSec + APIKey: "", + Endpoint: "/", + LogBlocked: false, + Timeout: 5 * time.Second, + } +} + +func (m *crowdsecMiddleware) finalize() error { + if !strings.HasPrefix(m.Endpoint, "/") { + return fmt.Errorf("endpoint must start with /") + } + if m.Timeout == 0 { + m.Timeout = 5 * time.Second + } + m.httpClient = &http.Client{ + Timeout: m.Timeout, + // do not follow redirects + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + return nil +} + +// before implements RequestModifier. +func (m *crowdsecMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + // Build CrowdSec URL + crowdsecURL, err := m.buildCrowdSecURL() + if err != nil { + Crowdsec.LogError(r).Err(err).Msg("failed to build CrowdSec URL") + w.WriteHeader(http.StatusInternalServerError) + return false + } + + // Determine HTTP method: GET for requests without body, POST for requests with body + method := http.MethodGet + var body io.Reader + if r.Body != nil && r.Body != http.NoBody { + method = http.MethodPost + // Read the body + bodyBytes, release, err := httputils.ReadAllRequestBody(r) + if err != nil { + Crowdsec.LogError(r).Err(err).Msg("failed to read request body") + w.WriteHeader(http.StatusInternalServerError) + return false + } + r.Body = ioutils.NewHookReadCloser(io.NopCloser(bytes.NewReader(bodyBytes)), func() { + release(bodyBytes) + }) + body = bytes.NewReader(bodyBytes) + } + + ctx, cancel := context.WithTimeout(r.Context(), m.Timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, method, crowdsecURL, body) + if err != nil { + Crowdsec.LogError(r).Err(err).Msg("failed to create CrowdSec request") + w.WriteHeader(http.StatusInternalServerError) + return false + } + + // Get remote IP + remoteIP, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + remoteIP = r.RemoteAddr + } + + // Get HTTP version in integer form (10, 11, 20, etc.) + httpVersion := m.getHTTPVersion(r) + + // Copy original headers + req.Header = r.Header.Clone() + + // Overwrite CrowdSec required headers to prevent spoofing + req.Header.Set("X-Crowdsec-Appsec-Ip", remoteIP) + req.Header.Set("X-Crowdsec-Appsec-Uri", r.URL.RequestURI()) + req.Header.Set("X-Crowdsec-Appsec-Host", r.Host) + req.Header.Set("X-Crowdsec-Appsec-Verb", r.Method) + req.Header.Set("X-Crowdsec-Appsec-Api-Key", m.APIKey) + req.Header.Set("X-Crowdsec-Appsec-User-Agent", r.UserAgent()) + req.Header.Set("X-Crowdsec-Appsec-Http-Version", httpVersion) + + // Make request to CrowdSec + resp, err := m.httpClient.Do(req) + if err != nil { + Crowdsec.LogError(r).Err(err).Msg("failed to connect to CrowdSec server") + w.WriteHeader(http.StatusInternalServerError) + return false + } + defer resp.Body.Close() + + // Handle response codes + switch resp.StatusCode { + case http.StatusOK: + // Request is allowed + return true + case http.StatusForbidden: + // Request is blocked by CrowdSec + if m.LogBlocked { + Crowdsec.LogWarn(r). + Str("ip", remoteIP). + Msg("request blocked by CrowdSec") + } + w.WriteHeader(http.StatusForbidden) + return false + case http.StatusInternalServerError: + // CrowdSec server error + bodyBytes, release, err := httputils.ReadAllBody(resp) + if err == nil { + defer release(bodyBytes) + Crowdsec.LogError(r). + Str("crowdsec_response", string(bodyBytes)). + Msg("CrowdSec server error") + } + w.WriteHeader(http.StatusInternalServerError) + return false + default: + // Unexpected response code + Crowdsec.LogWarn(r). + Int("status_code", resp.StatusCode). + Msg("unexpected response from CrowdSec server") + w.WriteHeader(http.StatusInternalServerError) + return false + } +} + +// buildCrowdSecURL constructs the CrowdSec server URL based on route or IP configuration +func (m *crowdsecMiddleware) buildCrowdSecURL() (string, error) { + // Try to get route first + if m.Route != "" { + if route, ok := routes.HTTP.Get(m.Route); ok { + // Using route name + targetURL := *route.TargetURL() + targetURL.Path = m.Endpoint + return targetURL.String(), nil + } + + // If not found in routes, assume it's an IP address + if m.Port == 0 { + return "", fmt.Errorf("port must be specified when using IP address") + } + return fmt.Sprintf("http://%s%s", net.JoinHostPort(m.Route, strconv.Itoa(m.Port)), m.Endpoint), nil + } + + return "", fmt.Errorf("route or IP address must be specified") +} + +func (m *crowdsecMiddleware) getHTTPVersion(r *http.Request) string { + switch { + case r.ProtoMajor == 1 && r.ProtoMinor == 0: + return "10" + case r.ProtoMajor == 1 && r.ProtoMinor == 1: + return "11" + case r.ProtoMajor == 2: + return "20" + case r.ProtoMajor == 3: + return "30" + default: + return strconv.Itoa(r.ProtoMajor*10 + r.ProtoMinor) + } +} diff --git a/internal/net/gphttp/middleware/middlewares.go b/internal/net/gphttp/middleware/middlewares.go index bf0699e4..76b28682 100644 --- a/internal/net/gphttp/middleware/middlewares.go +++ b/internal/net/gphttp/middleware/middlewares.go @@ -19,6 +19,7 @@ var allMiddlewares = map[string]*Middleware{ "oidc": OIDC, "forwardauth": ForwardAuth, + "crowdsec": Crowdsec, "request": ModifyRequest, "modifyrequest": ModifyRequest,