mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-18 14:39:49 +02:00
simplify setup process
This commit is contained in:
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/agent/pkg/certs"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
@@ -94,6 +93,14 @@ func (cfg *AgentConfig) errIfNameExists() E.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func withoutBuildTime(version string) string {
|
||||
return strings.Split(version, "-")[0]
|
||||
}
|
||||
|
||||
func checkVersion(a, b string) bool {
|
||||
return withoutBuildTime(a) == withoutBuildTime(b)
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) load() E.Error {
|
||||
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
|
||||
if err != nil {
|
||||
@@ -132,15 +139,13 @@ func (cfg *AgentConfig) load() E.Error {
|
||||
defer cancel()
|
||||
|
||||
// check agent version
|
||||
if !env.AgentSkipVersionCheck {
|
||||
version, _, err := cfg.Fetch(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return E.Wrap(err)
|
||||
}
|
||||
version, _, err := cfg.Fetch(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return E.Wrap(err)
|
||||
}
|
||||
|
||||
if string(version) != pkg.GetVersion() {
|
||||
return E.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), string(version))
|
||||
}
|
||||
if !checkVersion(string(version), pkg.GetVersion()) {
|
||||
return E.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), string(version))
|
||||
}
|
||||
|
||||
// get agent name
|
||||
|
||||
30
agent/pkg/agent/utils.go
Normal file
30
agent/pkg/agent/utils.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func MachineIP() (string, bool) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
interfaces = []net.Interface{}
|
||||
}
|
||||
for _, in := range interfaces {
|
||||
addrs, err := in.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(in.Name, "eth") && !strings.HasPrefix(in.Name, "en") {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
|
||||
if ipnet.IP.To4() != nil {
|
||||
return ipnet.IP.String(), true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
@@ -32,6 +32,7 @@ func readFile(f *zip.File) ([]byte, error) {
|
||||
|
||||
func ZipCert(ca, crt, key []byte) ([]byte, error) {
|
||||
data := bytes.NewBuffer(nil)
|
||||
data.Grow(6144)
|
||||
zipWriter := zip.NewWriter(data)
|
||||
defer zipWriter.Close()
|
||||
|
||||
|
||||
56
agent/pkg/env/env.go
vendored
56
agent/pkg/env/env.go
vendored
@@ -1,7 +1,10 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
@@ -15,7 +18,54 @@ func DefaultAgentName() string {
|
||||
}
|
||||
|
||||
var (
|
||||
AgentName = common.GetEnvString("AGENT_NAME", DefaultAgentName())
|
||||
AgentPort = common.GetEnvInt("AGENT_PORT", 8890)
|
||||
AgentSkipVersionCheck = common.GetEnvBool("AGENT_SKIP_VERSION_CHECK", false)
|
||||
AgentName = common.GetEnvString("AGENT_NAME", DefaultAgentName())
|
||||
AgentPort = common.GetEnvInt("AGENT_PORT", 8890)
|
||||
AgentRegistrationPort = common.GetEnvInt("AGENT_REGISTRATION_PORT", 8891)
|
||||
AgentSkipClientCertCheck = common.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
|
||||
|
||||
RegistrationAllowedHosts = common.GetCommaSepEnv("REGISTRATION_ALLOWED_HOSTS", "")
|
||||
RegistrationAllowedCIDRs []*net.IPNet
|
||||
)
|
||||
|
||||
func init() {
|
||||
cidrs, err := toCIDRs(RegistrationAllowedHosts)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to parse allowed hosts: %v", err)
|
||||
}
|
||||
if len(cidrs) == 0 {
|
||||
log.Fatal("REGISTRATION_ALLOWED_HOSTS is empty")
|
||||
}
|
||||
RegistrationAllowedCIDRs = cidrs
|
||||
}
|
||||
|
||||
func toCIDRs(hosts []string) ([]*net.IPNet, error) {
|
||||
var cidrs []*net.IPNet
|
||||
for _, host := range hosts {
|
||||
if !strings.Contains(host, "/") {
|
||||
host += "/32"
|
||||
}
|
||||
_, cidr, err := net.ParseCIDR(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cidrs = append(cidrs, cidr)
|
||||
}
|
||||
return cidrs, nil
|
||||
}
|
||||
|
||||
func IsAllowedHost(remoteAddr string) bool {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
ip = remoteAddr
|
||||
}
|
||||
netIP := net.ParseIP(ip)
|
||||
if netIP == nil {
|
||||
return false
|
||||
}
|
||||
for _, cidr := range RegistrationAllowedCIDRs {
|
||||
if cidr.Contains(netIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/agent/pkg/certs"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
@@ -32,7 +39,7 @@ func (NopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewHandler() http.Handler {
|
||||
func NewAgentHandler() http.Handler {
|
||||
mux := ServeMux{http.NewServeMux()}
|
||||
|
||||
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
||||
@@ -46,3 +53,46 @@ func NewHandler() http.Handler {
|
||||
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
|
||||
return mux
|
||||
}
|
||||
|
||||
// NewRegistrationHandler creates a new registration handler
|
||||
// It checks if the request is coming from an allowed host
|
||||
// Generates a new client certificate and zips it
|
||||
// Sends the zipped certificate to the client
|
||||
// its run only once on agent first start.
|
||||
func NewRegistrationHandler(task *task.Task, ca *tls.Certificate) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !env.IsAllowedHost(r.RemoteAddr) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/done" {
|
||||
logging.Info().Msg("registration done")
|
||||
task.Finish(nil)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
logging.Info().Msgf("received registration request from %s", r.RemoteAddr)
|
||||
|
||||
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Certificate[0]})
|
||||
|
||||
crt, key, err := certs.NewClientCert(ca)
|
||||
if err != nil {
|
||||
utils.HandleErr(w, r, E.Wrap(err, "failed to generate client certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
zipped, err := certs.ZipCert(caPEM, crt, key)
|
||||
if err != nil {
|
||||
utils.HandleErr(w, r, E.Wrap(err, "failed to zip certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
if _, err := w.Write(zipped); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to respond to registration request")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
"github.com/yusing/go-proxy/agent/pkg/handler"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/http/server"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
@@ -23,7 +24,7 @@ type Options struct {
|
||||
}
|
||||
|
||||
func StartAgentServer(parent task.Parent, opt Options) {
|
||||
t := parent.Subtask("agent server")
|
||||
t := parent.Subtask("agent_server")
|
||||
|
||||
caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: opt.CACert.Certificate[0]})
|
||||
caCertPool := x509.NewCertPool()
|
||||
@@ -36,23 +37,24 @@ func StartAgentServer(parent task.Parent, opt Options) {
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
if env.AgentSkipClientCertCheck {
|
||||
tlsConfig.ClientAuth = tls.NoClientCert
|
||||
}
|
||||
l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port))
|
||||
if err != nil {
|
||||
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port")
|
||||
return
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler.NewHandler(),
|
||||
agentServer := &http.Server{
|
||||
Handler: handler.NewAgentHandler(),
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: log.New(logging.GetLogger(), "", 0),
|
||||
}
|
||||
|
||||
go func() {
|
||||
l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port))
|
||||
if err != nil {
|
||||
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port")
|
||||
return
|
||||
}
|
||||
defer l.Close()
|
||||
if err := server.Serve(tls.NewListener(l, tlsConfig)); err != nil {
|
||||
if err := agentServer.Serve(tls.NewListener(l, tlsConfig)); err != nil {
|
||||
logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to serve")
|
||||
}
|
||||
}()
|
||||
@@ -66,10 +68,38 @@ func StartAgentServer(parent task.Parent, opt Options) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := server.Shutdown(ctx)
|
||||
err := agentServer.Shutdown(ctx)
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Int("port", opt.Port).Msg("failed to shutdown agent server")
|
||||
}
|
||||
logging.Info().Int("port", opt.Port).Msg("agent server stopped")
|
||||
}()
|
||||
}
|
||||
|
||||
func StartRegistrationServer(parent task.Parent, opt Options) {
|
||||
t := parent.Subtask("registration_server")
|
||||
|
||||
registrationServer := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", opt.Port),
|
||||
Handler: handler.NewRegistrationHandler(t, opt.CACert),
|
||||
ErrorLog: log.New(logging.GetLogger(), "", 0),
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := registrationServer.ListenAndServe()
|
||||
server.HandleError(logging.GetLogger(), err)
|
||||
}()
|
||||
|
||||
logging.Info().Int("port", opt.Port).Msg("registration server started")
|
||||
|
||||
defer t.Finish(nil)
|
||||
<-t.Context().Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := registrationServer.Shutdown(ctx)
|
||||
server.HandleError(logging.GetLogger(), err)
|
||||
|
||||
logging.Info().Int("port", opt.Port).Msg("registration server stopped")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user