From 602b0c8591cbd23854b4b1d3c28e5fd4f8c0a295 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Vav=C5=99=C3=ADk?= Date: Tue, 11 May 2021 23:05:21 +0200 Subject: [PATCH] Added mTLS support --- cmd/register.go | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/cmd/register.go b/cmd/register.go index ac7c3fd..a037e55 100644 --- a/cmd/register.go +++ b/cmd/register.go @@ -17,7 +17,11 @@ package cmd import ( "context" + "crypto/tls" + "crypto/x509" "fmt" + "google.golang.org/grpc/credentials" + "io/ioutil" "log" "strconv" "strings" @@ -25,7 +29,6 @@ import ( "github.com/dexidp/dex/api/v2" "github.com/spf13/cobra" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) // registerCmd represents the register command @@ -41,6 +44,8 @@ var registerCmd = &cobra.Command{ var host string var port int var caPath string +var clientCert string +var clientKey string var clientId string var clientSecret string var redirectUris []string @@ -55,8 +60,10 @@ func init() { _ = registerCmd.MarkFlagRequired("address") registerCmd.Flags().IntVarP(&port, "port", "p", 5557, "Host port to connect to") _ = registerCmd.MarkFlagRequired("port") - registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/", "Path to client CA cert to connect to") + registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/etc/dex/ca.crt", "Path to client CA cert to connect to") _ = registerCmd.MarkFlagRequired("cacertpath") + registerCmd.Flags().StringVarP(&clientCert, "clientCert", "e", "", "Path to client cert for mTLS") + registerCmd.Flags().StringVarP(&clientKey, "clientKey", "k", "", "Path to client key for mTLS") registerCmd.Flags().StringVarP(&clientId, "clientid", "c", "", "ClientID to register") _ = registerCmd.MarkFlagRequired("clientid") @@ -67,12 +74,37 @@ func init() { } func newDexClient(hostAndPort string) (api.DexClient, error) { - creds, err := credentials.NewClientTLSFromFile(caPath, "") + //---------- TLS Setting -----------// + clientCertificate, err := tls.LoadX509KeyPair( + clientCert, + clientKey, + ) if err != nil { - return nil, fmt.Errorf("load dex cert: %v", err) + log.Fatalf("failed to read client cert or key: %s", err) + } + serverCertPool := x509.NewCertPool() + bs, err := ioutil.ReadFile(caPath) + if err != nil { + log.Fatalf("failed to read ca cert: %s", err) } - conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(creds)) + ok := serverCertPool.AppendCertsFromPEM(bs) + if !ok { + log.Fatal("failed to append certs") + } + var transportCreds credentials.TransportCredentials + if clientCert != "" { + transportCreds = credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{clientCertificate}, + RootCAs: serverCertPool, + }) + } else { + transportCreds = credentials.NewTLS(&tls.Config{ + RootCAs: serverCertPool, + }) + } + + conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(transportCreds)) if err != nil { return nil, fmt.Errorf("dial: %v", err) }