mirror of
https://github.com/ysoftdevs/dexregistrar.git
synced 2026-03-20 08:14:41 +01:00
Added mTLS support
This commit is contained in:
@@ -17,7 +17,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -25,7 +29,6 @@ import (
|
||||
"github.com/dexidp/dex/api/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// registerCmd represents the register command
|
||||
@@ -41,6 +44,8 @@ var registerCmd = &cobra.Command{
|
||||
var host string
|
||||
var port int
|
||||
var caPath string
|
||||
var clientCert string
|
||||
var clientKey string
|
||||
var clientId string
|
||||
var clientSecret string
|
||||
var redirectUris []string
|
||||
@@ -55,8 +60,10 @@ func init() {
|
||||
_ = registerCmd.MarkFlagRequired("address")
|
||||
registerCmd.Flags().IntVarP(&port, "port", "p", 5557, "Host port to connect to")
|
||||
_ = registerCmd.MarkFlagRequired("port")
|
||||
registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/", "Path to client CA cert to connect to")
|
||||
registerCmd.Flags().StringVarP(&caPath, "cacertpath", "t", "/etc/dex/ca.crt", "Path to client CA cert to connect to")
|
||||
_ = registerCmd.MarkFlagRequired("cacertpath")
|
||||
registerCmd.Flags().StringVarP(&clientCert, "clientCert", "e", "", "Path to client cert for mTLS")
|
||||
registerCmd.Flags().StringVarP(&clientKey, "clientKey", "k", "", "Path to client key for mTLS")
|
||||
|
||||
registerCmd.Flags().StringVarP(&clientId, "clientid", "c", "", "ClientID to register")
|
||||
_ = registerCmd.MarkFlagRequired("clientid")
|
||||
@@ -67,12 +74,37 @@ func init() {
|
||||
}
|
||||
|
||||
func newDexClient(hostAndPort string) (api.DexClient, error) {
|
||||
creds, err := credentials.NewClientTLSFromFile(caPath, "")
|
||||
//---------- TLS Setting -----------//
|
||||
clientCertificate, err := tls.LoadX509KeyPair(
|
||||
clientCert,
|
||||
clientKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load dex cert: %v", err)
|
||||
log.Fatalf("failed to read client cert or key: %s", err)
|
||||
}
|
||||
serverCertPool := x509.NewCertPool()
|
||||
bs, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read ca cert: %s", err)
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(creds))
|
||||
ok := serverCertPool.AppendCertsFromPEM(bs)
|
||||
if !ok {
|
||||
log.Fatal("failed to append certs")
|
||||
}
|
||||
var transportCreds credentials.TransportCredentials
|
||||
if clientCert != "" {
|
||||
transportCreds = credentials.NewTLS(&tls.Config{
|
||||
Certificates: []tls.Certificate{clientCertificate},
|
||||
RootCAs: serverCertPool,
|
||||
})
|
||||
} else {
|
||||
transportCreds = credentials.NewTLS(&tls.Config{
|
||||
RootCAs: serverCertPool,
|
||||
})
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(hostAndPort, grpc.WithTransportCredentials(transportCreds))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user