simplify setup process

This commit is contained in:
yusing
2025-02-11 05:05:56 +08:00
parent 2c57e439d5
commit 3332ce34c5
21 changed files with 386 additions and 206 deletions

View File

@@ -13,7 +13,6 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/agent/pkg/env"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
@@ -94,6 +93,14 @@ func (cfg *AgentConfig) errIfNameExists() E.Error {
return nil
}
func withoutBuildTime(version string) string {
return strings.Split(version, "-")[0]
}
func checkVersion(a, b string) bool {
return withoutBuildTime(a) == withoutBuildTime(b)
}
func (cfg *AgentConfig) load() E.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil {
@@ -132,15 +139,13 @@ func (cfg *AgentConfig) load() E.Error {
defer cancel()
// check agent version
if !env.AgentSkipVersionCheck {
version, _, err := cfg.Fetch(ctx, EndpointVersion)
if err != nil {
return E.Wrap(err)
}
version, _, err := cfg.Fetch(ctx, EndpointVersion)
if err != nil {
return E.Wrap(err)
}
if string(version) != pkg.GetVersion() {
return E.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), string(version))
}
if !checkVersion(string(version), pkg.GetVersion()) {
return E.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), string(version))
}
// get agent name

30
agent/pkg/agent/utils.go Normal file
View File

@@ -0,0 +1,30 @@
package agent
import (
"net"
"strings"
)
func MachineIP() (string, bool) {
interfaces, err := net.Interfaces()
if err != nil {
interfaces = []net.Interface{}
}
for _, in := range interfaces {
addrs, err := in.Addrs()
if err != nil {
continue
}
if !strings.HasPrefix(in.Name, "eth") && !strings.HasPrefix(in.Name, "en") {
continue
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
return ipnet.IP.String(), true
}
}
}
}
return "", false
}

View File

@@ -32,6 +32,7 @@ func readFile(f *zip.File) ([]byte, error) {
func ZipCert(ca, crt, key []byte) ([]byte, error) {
data := bytes.NewBuffer(nil)
data.Grow(6144)
zipWriter := zip.NewWriter(data)
defer zipWriter.Close()

56
agent/pkg/env/env.go vendored
View File

@@ -1,7 +1,10 @@
package env
import (
"log"
"net"
"os"
"strings"
"github.com/yusing/go-proxy/internal/common"
)
@@ -15,7 +18,54 @@ func DefaultAgentName() string {
}
var (
AgentName = common.GetEnvString("AGENT_NAME", DefaultAgentName())
AgentPort = common.GetEnvInt("AGENT_PORT", 8890)
AgentSkipVersionCheck = common.GetEnvBool("AGENT_SKIP_VERSION_CHECK", false)
AgentName = common.GetEnvString("AGENT_NAME", DefaultAgentName())
AgentPort = common.GetEnvInt("AGENT_PORT", 8890)
AgentRegistrationPort = common.GetEnvInt("AGENT_REGISTRATION_PORT", 8891)
AgentSkipClientCertCheck = common.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
RegistrationAllowedHosts = common.GetCommaSepEnv("REGISTRATION_ALLOWED_HOSTS", "")
RegistrationAllowedCIDRs []*net.IPNet
)
func init() {
cidrs, err := toCIDRs(RegistrationAllowedHosts)
if err != nil {
log.Fatalf("failed to parse allowed hosts: %v", err)
}
if len(cidrs) == 0 {
log.Fatal("REGISTRATION_ALLOWED_HOSTS is empty")
}
RegistrationAllowedCIDRs = cidrs
}
func toCIDRs(hosts []string) ([]*net.IPNet, error) {
var cidrs []*net.IPNet
for _, host := range hosts {
if !strings.Contains(host, "/") {
host += "/32"
}
_, cidr, err := net.ParseCIDR(host)
if err != nil {
return nil, err
}
cidrs = append(cidrs, cidr)
}
return cidrs, nil
}
func IsAllowedHost(remoteAddr string) bool {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr
}
netIP := net.ParseIP(ip)
if netIP == nil {
return false
}
for _, cidr := range RegistrationAllowedCIDRs {
if cidr.Contains(netIP) {
return true
}
}
return false
}

View File

@@ -1,14 +1,21 @@
package handler
import (
"crypto/tls"
"encoding/pem"
"fmt"
"io"
"net/http"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/agent/pkg/env"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/utils"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -32,7 +39,7 @@ func (NopWriteCloser) Close() error {
return nil
}
func NewHandler() http.Handler {
func NewAgentHandler() http.Handler {
mux := ServeMux{http.NewServeMux()}
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
@@ -46,3 +53,46 @@ func NewHandler() http.Handler {
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
return mux
}
// NewRegistrationHandler creates a new registration handler
// It checks if the request is coming from an allowed host
// Generates a new client certificate and zips it
// Sends the zipped certificate to the client
// its run only once on agent first start.
func NewRegistrationHandler(task *task.Task, ca *tls.Certificate) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !env.IsAllowedHost(r.RemoteAddr) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if r.URL.Path == "/done" {
logging.Info().Msg("registration done")
task.Finish(nil)
w.WriteHeader(http.StatusOK)
return
}
logging.Info().Msgf("received registration request from %s", r.RemoteAddr)
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Certificate[0]})
crt, key, err := certs.NewClientCert(ca)
if err != nil {
utils.HandleErr(w, r, E.Wrap(err, "failed to generate client certificate"))
return
}
zipped, err := certs.ZipCert(caPEM, crt, key)
if err != nil {
utils.HandleErr(w, r, E.Wrap(err, "failed to zip certificate"))
return
}
w.Header().Set("Content-Type", "application/zip")
if _, err := w.Write(zipped); err != nil {
logging.Error().Err(err).Msg("failed to respond to registration request")
return
}
}
}

View File

@@ -11,9 +11,10 @@ import (
"net/http"
"time"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/handler"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/server"
"github.com/yusing/go-proxy/internal/task"
)
@@ -23,7 +24,7 @@ type Options struct {
}
func StartAgentServer(parent task.Parent, opt Options) {
t := parent.Subtask("agent server")
t := parent.Subtask("agent_server")
caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: opt.CACert.Certificate[0]})
caCertPool := x509.NewCertPool()
@@ -36,23 +37,24 @@ func StartAgentServer(parent task.Parent, opt Options) {
ClientAuth: tls.RequireAndVerifyClientCert,
}
if common.IsDebug {
if env.AgentSkipClientCertCheck {
tlsConfig.ClientAuth = tls.NoClientCert
}
l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port))
if err != nil {
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port")
return
}
server := &http.Server{
Handler: handler.NewHandler(),
agentServer := &http.Server{
Handler: handler.NewAgentHandler(),
TLSConfig: tlsConfig,
ErrorLog: log.New(logging.GetLogger(), "", 0),
}
go func() {
l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port))
if err != nil {
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port")
return
}
defer l.Close()
if err := server.Serve(tls.NewListener(l, tlsConfig)); err != nil {
if err := agentServer.Serve(tls.NewListener(l, tlsConfig)); err != nil {
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to serve")
}
}()
@@ -66,10 +68,38 @@ func StartAgentServer(parent task.Parent, opt Options) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := server.Shutdown(ctx)
err := agentServer.Shutdown(ctx)
if err != nil {
logging.Error().Err(err).Int("port", opt.Port).Msg("failed to shutdown agent server")
}
logging.Info().Int("port", opt.Port).Msg("agent server stopped")
}()
}
func StartRegistrationServer(parent task.Parent, opt Options) {
t := parent.Subtask("registration_server")
registrationServer := &http.Server{
Addr: fmt.Sprintf(":%d", opt.Port),
Handler: handler.NewRegistrationHandler(t, opt.CACert),
ErrorLog: log.New(logging.GetLogger(), "", 0),
}
go func() {
err := registrationServer.ListenAndServe()
server.HandleError(logging.GetLogger(), err)
}()
logging.Info().Int("port", opt.Port).Msg("registration server started")
defer t.Finish(nil)
<-t.Context().Done()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := registrationServer.Shutdown(ctx)
server.HandleError(logging.GetLogger(), err)
logging.Info().Int("port", opt.Port).Msg("registration server stopped")
}