Files
godoxy/internal/entrypoint/http_server.go
yusing 2a3823091d feat(entrypoint): add inbound mTLS profiles for HTTPS
Add root-level inbound_mtls_profiles combining optional system CAs with PEM
CA files, and entrypoint.inbound_mtls_profile to require client certificates
on every HTTPS connection. Route-level inbound_mtls_profile is allowed only
without a global profile; per-handshake TLS picks ClientCAs from SNI, and
requests fail with 421 when Host and SNI would select different mTLS routes.

Compile pools at init (SetInboundMTLSProfiles from state.initEntrypoint) and
reject unknown profile refs or mixed global-plus-route configuration.

Extend config.example.yml and package READMEs; add entrypoint and config
tests for TLS mutation, handshakes, and validation.
2026-04-13 15:14:57 +08:00

224 lines
5.9 KiB
Go

package entrypoint
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/rs/zerolog/log"
acl "github.com/yusing/godoxy/internal/acl/types"
autocert "github.com/yusing/godoxy/internal/autocert/types"
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
"github.com/yusing/godoxy/internal/net/gphttp/middleware/errorpage"
"github.com/yusing/godoxy/internal/route/routes"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/pool"
"github.com/yusing/goutils/server"
)
// HTTPServer is a server that listens on a given address and serves HTTP routes.
type HTTPServer interface {
	// Listen starts serving on addr using the given protocol.
	Listen(addr string, proto HTTPProto) error
	// AddRoute registers route with this server.
	AddRoute(route types.HTTPRoute)
	// DelRoute removes route from this server.
	DelRoute(route types.HTTPRoute)
	// FindRoute looks up the route for s (a request host or TLS server name
	// at the call sites in this file); those callers treat a nil result as
	// "no matching route".
	FindRoute(s string) types.HTTPRoute
	// ServeHTTP handles an incoming request, making the server usable as an
	// http.Handler.
	ServeHTTP(w http.ResponseWriter, r *http.Request)
}
// httpServer is the HTTPServer implementation bound to a single Entrypoint.
type httpServer struct {
	// ep is the entrypoint that supplies configuration, middleware, access
	// logging and the route lookup function.
	ep *Entrypoint
	// stopFunc is set by a successful Listen and invoked by Close; it is nil
	// until the server has been started.
	stopFunc func(reason any)
	// addr is the listen address; a non-empty value marks the server as
	// already started.
	addr string
	// routes holds the HTTP routes served by this server; it is created in
	// Listen.
	routes *pool.Pool[types.HTTPRoute]
}
// HTTPProto selects the protocol a server listens with.
type HTTPProto string

const (
	// HTTPProtoHTTP makes Listen bind the address as a plain HTTP listener.
	HTTPProtoHTTP HTTPProto = "http"
	// HTTPProtoHTTPS makes Listen bind the address as an HTTPS listener with
	// a certificate provider and TLS config mutator attached.
	HTTPProtoHTTPS HTTPProto = "https"
)
// NewHTTPServer returns a not-yet-started HTTPServer bound to ep.
// Call Listen on the result to start serving.
func NewHTTPServer(ep *Entrypoint) HTTPServer {
	return newHTTPServer(ep)
}
// newHTTPServer returns the concrete, not-yet-started server bound to ep.
// Only ep is set; every other field keeps its zero value until Listen runs.
func newHTTPServer(ep *Entrypoint) *httpServer {
	return &httpServer{ep: ep}
}
// Listen starts the server on addr using proto. The server runs on a subtask
// of the entrypoint's task and can also be stopped explicitly with Close.
//
// It returns an error if this server has already been started, or if the
// underlying server fails to start. On failure the server is left in its
// not-started state, so Listen may be called again.
func (srv *httpServer) Listen(addr string, proto HTTPProto) error {
	if srv.addr != "" {
		return errors.New("server already started")
	}

	opts := server.Options{
		Name:                 addr,
		Handler:              srv,
		ACL:                  acl.FromCtx(srv.ep.task.Context()),
		SupportProxyProtocol: srv.ep.cfg.SupportProxyProtocol,
	}

	switch proto {
	case HTTPProtoHTTP:
		opts.HTTPAddr = addr
	case HTTPProtoHTTPS:
		opts.HTTPSAddr = addr
		opts.CertProvider = autocert.FromCtx(srv.ep.task.Context())
		opts.TLSConfigMutator = srv.mutateServerTLSConfig
	}

	// The route pool must exist before the server starts: srv itself is the
	// handler, so ServeHTTP (and therefore FindRoute) can run as soon as the
	// listener accepts its first connection. Creating the pool afterwards
	// would leave a window in which requests see a nil pool.
	prevRoutes := srv.routes
	routes := pool.New[types.HTTPRoute](fmt.Sprintf("[%s] %s", proto, addr), "http_routes")
	routes.DisableLog(srv.ep.httpPoolDisableLog.Load())
	srv.routes = routes

	task := srv.ep.task.Subtask("http_server", false)
	_, err := server.StartServer(task, opts)
	if err != nil {
		// Roll back so a failed start leaves the server exactly as it was.
		srv.routes = prevRoutes
		return err
	}

	srv.stopFunc = task.FinishAndWait
	srv.addr = addr
	return nil
}
// Close stops the server if Listen previously started it; otherwise it does
// nothing. The stop function is invoked with a nil reason.
func (srv *httpServer) Close() {
	if stop := srv.stopFunc; stop != nil {
		stop(nil)
	}
}
// AddRoute adds route to this server's route pool.
// The pool is created by Listen, so Listen must have succeeded first.
func (srv *httpServer) AddRoute(route types.HTTPRoute) {
	srv.routes.Add(route)
}
// DelRoute removes route from this server's route pool.
// The pool is created by Listen, so Listen must have succeeded first.
func (srv *httpServer) DelRoute(route types.HTTPRoute) {
	srv.routes.Del(route)
}
// FindRoute looks up s in this server's route pool using the entrypoint's
// configured route lookup function.
func (srv *httpServer) FindRoute(s string) types.HTTPRoute {
	return srv.ep.findRouteFunc(srv.routes, s)
}
// ServeHTTP resolves the route for the request and dispatches to it.
//
// Dispatch order:
//  1. a route resolution error is answered with 421 Misdirected Request;
//  2. a resolved route is served, through the entrypoint middleware if set;
//  3. a short-link request is handled by the short-link matcher;
//  4. the entrypoint's custom not-found handler is used if configured;
//  5. otherwise the default not-found response is written.
//
// When an access logger is configured, the response is recorded and logged
// once the handler returns.
func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if srv.ep.accessLogger != nil {
		recorder := accesslog.GetResponseRecorder(w)
		w = recorder
		// The closure reads r when it runs, so it logs the request including
		// any route context attached below.
		defer func() {
			// there is no body to close
			//nolint:bodyclose
			srv.ep.accessLogger.LogRequest(r, recorder.Response())
			accesslog.PutResponseRecorder(recorder)
		}()
	}

	route, err := srv.resolveRequestRoute(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusMisdirectedRequest)
		return
	}

	if route != nil {
		r = routes.WithRouteContext(r, route)
		if srv.ep.middleware == nil {
			route.ServeHTTP(w, r)
			return
		}
		srv.ep.middleware.ServeHTTP(route.ServeHTTP, w, r)
		return
	}

	if srv.tryHandleShortLink(w, r) {
		return
	}

	if srv.ep.notFoundHandler != nil {
		srv.ep.notFoundHandler.ServeHTTP(w, r)
		return
	}

	serveNotFound(w, r)
}
var (
	// errSecureRouteRequiresSNI is returned by resolveRequestRoute when the
	// TLS connection carried no server name but the route selected by the
	// Host header has an inbound mTLS profile.
	errSecureRouteRequiresSNI = errors.New("secure route requires matching TLS SNI")
	// errSecureRouteMisdirected is returned by resolveRequestRoute when the
	// Host header and the TLS server name do not select the same route and
	// at least one of the two routes has an inbound mTLS profile.
	errSecureRouteMisdirected = errors.New("secure route host must match TLS SNI")
)
// resolveRequestRoute picks the route for req and enforces that routes with an
// inbound mTLS profile are only reached when the TLS server name (SNI) and
// the Host header agree.
//
// The Host-selected route is returned unchecked when the request is not TLS,
// when a global inbound mTLS profile is configured, or when no inbound mTLS
// profiles exist. Otherwise:
//   - without SNI, a Host route that has an mTLS profile is rejected with
//     errSecureRouteRequiresSNI;
//   - with SNI, if the SNI route has an mTLS profile it must be the same
//     route as the Host route, else errSecureRouteMisdirected;
//   - with SNI, if only the Host route has an mTLS profile the request is
//     rejected with errSecureRouteMisdirected.
//
// The returned route is nil whenever the error is non-nil.
func (srv *httpServer) resolveRequestRoute(req *http.Request) (types.HTTPRoute, error) {
	hostRoute := srv.FindRoute(req.Host)

	perRouteMTLS := req.TLS != nil &&
		srv.ep.cfg.InboundMTLSProfile == "" &&
		len(srv.ep.inboundMTLSProfiles) > 0
	if !perRouteMTLS {
		return hostRoute, nil
	}

	sni := req.TLS.ServerName
	if sni == "" {
		// No SNI means the handshake could not have selected a per-route
		// client CA pool, so a protected Host route must be refused.
		if srv.resolveInboundMTLSProfileForRoute(hostRoute) != nil {
			return nil, errSecureRouteRequiresSNI
		}
		return hostRoute, nil
	}

	sniRoute := srv.FindRoute(sni)
	switch {
	case srv.resolveInboundMTLSProfileForRoute(sniRoute) != nil:
		if sameHTTPRoute(hostRoute, sniRoute) {
			return sniRoute, nil
		}
		return nil, errSecureRouteMisdirected
	case srv.resolveInboundMTLSProfileForRoute(hostRoute) != nil:
		return nil, errSecureRouteMisdirected
	default:
		return hostRoute, nil
	}
}
// sameHTTPRoute reports whether left and right refer to the same route.
// Two nil routes are the same; a nil and a non-nil route are not. Non-nil
// routes are the same when they are equal or share the same Key.
func sameHTTPRoute(left, right types.HTTPRoute) bool {
	if left == nil || right == nil {
		return left == right
	}
	if left == right {
		return true
	}
	return left.Key() == right.Key()
}
// tryHandleShortLink serves the request with the short-link matcher when the
// request host, with everything from the first ':' onward removed, equals
// common.ShortLinkPrefix ignoring case. The matcher runs through the
// entrypoint middleware when one is configured.
//
// It reports whether the request was handled.
func (srv *httpServer) tryHandleShortLink(w http.ResponseWriter, r *http.Request) (handled bool) {
	// Cut yields the whole string as its first result when no ':' is present.
	hostname, _, _ := strings.Cut(r.Host, ":")
	if !strings.EqualFold(hostname, common.ShortLinkPrefix) {
		return false
	}

	if srv.ep.middleware == nil {
		srv.ep.shortLinkMatcher.ServeHTTP(w, r)
		return true
	}
	srv.ep.middleware.ServeHTTP(srv.ep.shortLinkMatcher.ServeHTTP, w, r)
	return true
}
// serveNotFound writes the default response for a request that matched no
// route. A static error page file is served if available; otherwise the miss
// is logged and the configured 404 error page, or the standard library's
// plain 404, is written.
func serveNotFound(w http.ResponseWriter, r *http.Request) {
	// Why StatusNotFound rather than StatusBadRequest or StatusBadGateway?
	// nginx answers with StatusBadGateway when no route exists for a domain,
	// which tells scrapers and scanners that the subdomain is invalid.
	// StatusNotFound leaves it ambiguous whether the path or the subdomain
	// is the invalid part.
	if middleware.ServeStaticErrorPageFile(w, r) {
		return
	}

	log.Warn().
		Str("method", r.Method).
		Str("url", r.URL.String()).
		Str("remote", r.RemoteAddr).
		Msgf("not found: %s", r.Host)

	page, found := errorpage.GetErrorPageByStatus(http.StatusNotFound)
	if !found {
		http.NotFound(w, r)
		return
	}

	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusNotFound)
	if _, err := w.Write(page); err != nil {
		log.Err(err).Msg("failed to write error page")
	}
}