Added mTLS support

This commit is contained in:
Jakub Vavřík
2021-05-11 23:05:21 +02:00
parent 75518fc898
commit 602b0c8591

View File

@@ -17,7 +17,11 @@ package cmd
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"google.golang.org/grpc/credentials"
"io/ioutil"
"log" "log"
"strconv" "strconv"
"strings" "strings"
@@ -25,7 +29,6 @@ import (
"github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/api/v2"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
// registerCmd represents the register command // registerCmd represents the register command
@@ -41,6 +44,8 @@ var registerCmd = &cobra.Command{
var host string var host string
var port int var port int
var caPath string var caPath string
var clientCert string
var clientKey string
var clientId string var clientId string
var clientSecret string var clientSecret string
var redirectUris []string var redirectUris []string
@@ -55,8 +60,10 @@ func init() {
_ = registerCmd.MarkFlagRequired("address") _ = registerCmd.MarkFlagRequired("address")
registerCmd.Flags().IntVarP(&port, "port", "p", 5557, "Host port to connect to") registerCmd.Flags().IntVarP(&port, "port", "p", 5557, "Host port to connect to")
_ = registerCmd.MarkFlagRequired("port") _ = registerCmd.MarkFlagRequired("port")
registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/", "Path to client CA cert to connect to") registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/etc/dex/ca.crt", "Path to client CA cert to connect to")
_ = registerCmd.MarkFlagRequired("cacertpath") _ = registerCmd.MarkFlagRequired("cacertpath")
registerCmd.Flags().StringVarP(&clientCert, "clientCert", "e", "", "Path to client cert for mTLS")
registerCmd.Flags().StringVarP(&clientKey, "clientKey", "k", "", "Path to client key for mTLS")
registerCmd.Flags().StringVarP(&clientId, "clientid", "c", "", "ClientID to register") registerCmd.Flags().StringVarP(&clientId, "clientid", "c", "", "ClientID to register")
_ = registerCmd.MarkFlagRequired("clientid") _ = registerCmd.MarkFlagRequired("clientid")
@@ -67,12 +74,37 @@ func init() {
} }
func newDexClient(hostAndPort string) (api.DexClient, error) { func newDexClient(hostAndPort string) (api.DexClient, error) {
creds, err := credentials.NewClientTLSFromFile(caPath, "") //---------- TLS Setting -----------//
clientCertificate, err := tls.LoadX509KeyPair(
clientCert,
clientKey,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("load dex cert: %v", err) log.Fatalf("failed to read client cert or key: %s", err)
}
serverCertPool := x509.NewCertPool()
bs, err := ioutil.ReadFile(caPath)
if err != nil {
log.Fatalf("failed to read ca cert: %s", err)
} }
conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(creds)) ok := serverCertPool.AppendCertsFromPEM(bs)
if !ok {
log.Fatal("failed to append certs")
}
var transportCreds credentials.TransportCredentials
if clientCert != "" {
transportCreds = credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{clientCertificate},
RootCAs: serverCertPool,
})
} else {
transportCreds = credentials.NewTLS(&tls.Config{
RootCAs: serverCertPool,
})
}
conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(transportCreds))
if err != nil { if err != nil {
return nil, fmt.Errorf("dial: %v", err) return nil, fmt.Errorf("dial: %v", err)
} }