diff --git a/go.mod b/go.mod index b1be642..bc3f8c9 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module libvault go 1.18 require ( + github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect + github.com/spf13/cobra v1.5.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect ) diff --git a/go.sum b/go.sum index 64cb7ad..f621d46 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,20 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= +github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mini-vault/mini-vault.go b/mini-vault/mini-vault.go new file mode 100644 index 0000000..d3cb844 --- /dev/null +++ b/mini-vault/mini-vault.go @@ -0,0 +1,75 @@ +////////////////////////////////////////////////////////////////////////// + +package main + +////////////////////////////////////////////////////////////////////////// + +import ( + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "os" + + vault "libvault" +) + +////////////////////////////////////////////////////////////////////////// +// everything starts here + +func main() { + log.SetLevel(log.ErrorLevel) + + cmdRoot := &cobra.Command{ + Use: "mini-vault", + Short: "Hashicorp Vault helper proglet", + } + cmdRoot.PersistentFlags().StringVarP(&TokenFile, "token", "t", "", "Token file") + + // configure subcommands + cmdToken := &cobra.Command{ + Use: "token", + Short: "token manipulation", + } + + cmdTokenRenew := &cobra.Command{ + Use: "renew", + Short: "Renew Token", + Run: CmdTokenRenew, + } + cmdTokenRenew.Flags().StringVarP(&TokenTTL, "ttl", "l", "", "Renewal TTL") + + cmdTLS := &cobra.Command{ + Use: "tls", + Short: "TLS cert management", + } + + cmdTLSRenew := &cobra.Command{ + Use: "renew", + Short: "Renew TLS certificate", + Run: CmdTLSRenew, + } + cmdTLSRenew.Flags().StringVarP(&TLSCertPEM, "cert", "c", "", "Path to Certificate PEM") + cmdTLSRenew.MarkFlagRequired("cert") + cmdTLSRenew.Flags().StringVarP(&TLSKeyPEM, "key", "k", "", "Path to Key PEM") + cmdTLSRenew.MarkFlagRequired("key") + cmdTLSRenew.Flags().StringVarP(&TLSCAPEM, "ca", "a", "", "Path to CA PEM") + cmdTLSRenew.MarkFlagRequired("ca") + cmdTLSRenew.Flags().StringVarP(&TLSRequest, "request", "r", "", "Request Parameters") + cmdTLSRenew.MarkFlagRequired("request") + + cmdRoot.AddCommand(cmdToken, cmdTLS) + cmdToken.AddCommand(cmdTokenRenew) + cmdTLS.AddCommand(cmdTLSRenew) + + // set vault address from environment + va := os.Getenv("VAULT_ADDR") + if va != "" { + vault.VAULT_ADDR = va + } + + // do it + cmdRoot.Execute() +} + +////////////////////////////////////////////////////////////////////////// +// end of code diff --git a/mini-vault/tls.go b/mini-vault/tls.go new file mode 100644 index 0000000..b73f046 --- /dev/null +++ b/mini-vault/tls.go @@ -0,0 +1,151 @@ +////////////////////////////////////////////////////////////////////////// + +package main + +////////////////////////////////////////////////////////////////////////// + +import ( + // log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "encoding/json" + "encoding/pem" + "fmt" + "os" + + vault "libvault" +) + +////////////////////////////////////////////////////////////////////////// + +var ( + TLSCertPEM string + TLSKeyPEM string + TLSCAPEM string + TLSRequest string +) + +////////////////////////////////////////////////////////////////////////// +// helper funcs + +func loadRequest(filename string) *vault.TLSRequest { + + content, err := os.ReadFile(filename) + if err != nil { + fmt.Printf("ERROR: failed to read request file (%s): %s\n", + filename, err) + os.Exit(1) + } + + request := &vault.TLSRequest{} + if err := json.Unmarshal(content, request); err != nil { + fmt.Printf("ERROR: failed to parse request file (%s): %s\n", + filename, err) + os.Exit(1) + } + + return request +} + +func writePEM(filename string, blocks []*pem.Block) { + + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + fmt.Printf("ERROR: failed to create PEM file (%s): %s\n", + filename, err) + os.Exit(1) + } + defer file.Close() + + for _, block := range blocks { + if err := pem.Encode(file, block); err != nil { + fmt.Printf("ERROR: failed to write PEM block (%s): %s\n", + filename, err) + os.Exit(1) + } + } + +} + +////////////////////////////////////////////////////////////////////////// + +func CmdTLSRenew(cmd *cobra.Command, args []string) { + + // load token and TLS request parameters + token := loadToken() + request := loadRequest(TLSRequest) + + // load existing cert if it existed + if _, err := os.Stat(TLSCertPEM); err == nil { + fmt.Printf("Loading existing certificate: %s\n", TLSCertPEM) + + data, err := os.ReadFile(TLSCertPEM) + if err != nil { + fmt.Printf("ERROR: failed to read existing certificate: %s\n", err) + os.Exit(1) + } + block, _ := pem.Decode(data) + if block == nil || block.Type != "CERTIFICATE" { + fmt.Println("ERROR: failed to parse PEM block") + os.Exit(1) + } + + // check if certificate needed renewing + renew, err := request.CheckRenew(block.Bytes) + if err != nil { + fmt.Printf("ERROR: failed to check existing certificate: %s\n", err) + os.Exit(1) + } + + if !renew { + // nothing to do + fmt.Println("Renewal not required, no action") + os.Exit(0) + } + } + + // issue the cert + kc, err := request.Issue(token) + if err != nil { + fmt.Printf("ERROR: failed to issue TLS cert: %s\n", err) + os.Exit(1) + } + + // write out the certs + fmt.Println("Success ! updating certs") + + fmt.Printf(" - Certificate: %s\n", TLSCertPEM) + if err := os.WriteFile( + TLSCertPEM, + []byte(kc.Certificate+"\n"+kc.IssuingCA), + 0600, + ); err != nil { + fmt.Printf("ERROR: failed to write certificate: %s\n", err) + os.Exit(1) + } + + fmt.Printf(" - Private Key: %s\n", TLSKeyPEM) + if err := os.WriteFile( + TLSKeyPEM, + []byte(kc.PrivateKey+"\n"), + 0600, + ); err != nil { + fmt.Printf("ERROR: failed to write key: %s\n", err) + os.Exit(1) + } + + fmt.Printf(" - CA: %s\n", TLSCAPEM) + if err := os.WriteFile( + TLSCAPEM, + []byte(kc.IssuingCA+"\n"), + 0600, + ); err != nil { + fmt.Printf("ERROR: failed to write CA: %s\n", err) + os.Exit(1) + } + + os.Exit(0) +} + +////////////////////////////////////////////////////////////////////////// +// end of code diff --git a/mini-vault/token.go b/mini-vault/token.go new file mode 100644 index 0000000..d6f1f06 --- /dev/null +++ b/mini-vault/token.go @@ -0,0 +1,95 @@ +////////////////////////////////////////////////////////////////////////// + +package main + +////////////////////////////////////////////////////////////////////////// + +import ( + // log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "fmt" + "os" + "time" + + vault "libvault" +) + +////////////////////////////////////////////////////////////////////////// + +var ( + TokenFile string + TokenTTL string +) + +////////////////////////////////////////////////////////////////////////// +// helper funcs + +func loadToken() *vault.Token { + + var token *vault.Token + + if os.Getenv("VAULT_TOKEN") == "" { + // if no env set, read token from a file + + filename := TokenFile + if filename == "" { + filename = vault.VAULT_TOKEN_FILE + } + fmt.Printf("Reading token from file: %s\n", filename) + + var err error + token, err = vault.NewTokenFromFile(filename) + if err != nil { + fmt.Printf("ERROR: Failed to read token from file: %s\n", err) + os.Exit(1) + } + + } else { + // obtain token from environment + + token = &vault.Token{} + token.Token = os.Getenv("VAULT_TOKEN") + + } + + return token +} + +////////////////////////////////////////////////////////////////////////// + +func CmdTokenRenew(cmd *cobra.Command, args []string) { + + token := loadToken() + + // set the renewal duration + var ttl time.Duration + if TokenTTL == "" { + ttl = vault.VAULT_TTL + } else { + var err error + ttl, err = time.ParseDuration(TokenTTL) + if err != nil { + fmt.Printf("ERROR: failed to parse TTL: %s\n", err) + os.Exit(1) + } + } + + fmt.Printf("Renewing token for %s\n", ttl.String()) + if err := token.Renew(ttl); err != nil { + fmt.Printf("ERROR: Failed to renew token: %s\n", err) + os.Exit(1) + } + + expiry, err := token.Expires() + if err != nil { + fmt.Printf("ERROR: renewed token, but couldn't get new expiry date: %s\n", err) + os.Exit(1) + } + + fmt.Printf("New token expiry date: %s\n", expiry.String()) + os.Exit(0) +} + +////////////////////////////////////////////////////////////////////////// +// end of code diff --git a/test/test.go b/test/test.go index 339f439..cf91337 100644 --- a/test/test.go +++ b/test/test.go @@ -1,6 +1,4 @@ ////////////////////////////////////////////////////////////////////////// -// burble.dn42 services -////////////////////////////////////////////////////////////////////////// package main diff --git a/tls.go b/tls.go index 4a0f5e6..86bde53 100644 --- a/tls.go +++ b/tls.go @@ -8,68 +8,111 @@ import ( "crypto/tls" "crypto/x509" log "github.com/sirupsen/logrus" + "math/rand" + "sync" "time" ) ////////////////////////////////////////////////////////////////////////// type TLSRequest struct { - CommonName string `json:"common_name"` - AltNames string `json:"alt_names"` - IPSANs string `json:"ip_sans"` - URISANs string `json:"uri_sans"` - OtherSANs string `json:"other_sans"` - TTL time.Duration `json:"ttl"` + CommonName string `json:"common_name"` + AltNames string `json:"alt_names"` + IPSANs string `json:"ip_sans"` + URISANs string `json:"uri_sans"` + OtherSANs string `json:"other_sans"` + TTL time.Duration `json:"ttl"` + RenewPeriod time.Duration `json:"-"` + stop chan bool `json:"-"` + done sync.WaitGroup `json:"-"` +} + +type TLSKeyCert struct { + Certificate string `json:"certificate"` + IssuingCA string `json:"issuing_ca"` + CAChain []string `json:"ca_chain"` + PrivateKey string `json:"private_key"` +} + +////////////////////////////////////////////////////////////////////////// +// issue a new certificate based on the request + +func (req *TLSRequest) Issue(t *Token) (*TLSKeyCert, error) { + + log.WithFields(log.Fields{ + "CommonName": req.CommonName, + }).Debug("libvault: Issuing certificate") + + // default the TTL if required + if req.TTL == 0 { + req.TTL = VAULT_TTL + } + + response := &struct { + Data *TLSKeyCert `json:"data"` + }{Data: &TLSKeyCert{}} + + if err := vault.POST(t, "/burble.dn42/pki/sites/issue/tls", + req, response); err != nil { + log.WithFields(log.Fields{ + "request": req, + "error": err, + }).Error("libvault: vault failed to issue certificate") + return nil, err + } + + return response.Data, nil +} + +////////////////////////////////////////////////////////////////////////// +// check if a certificate needs renewing + +func (req *TLSRequest) CheckRenew(cdata []byte) (bool, error) { + + // default the renew period + if req.RenewPeriod == 0 { + req.RenewPeriod = VAULT_RENEW_PERIOD + } + + // parse the certificate + cert, err := x509.ParseCertificate(cdata) + if err != nil { + log.WithFields(log.Fields{ + "error": err, + }).Error("libvault: failed to parse tls certificate") + return false, err + } + + // and check the ttl + ttl := cert.NotAfter.Sub(time.Now()) + return (ttl.Seconds() < req.RenewPeriod.Seconds()), nil } ////////////////////////////////////////////////////////////////////////// func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) { + // if there is an existing certificate, check if it needs renewing if len(config.Certificates) > 0 { - cert, err := x509.ParseCertificate(config.Certificates[0].Certificate[0]) + renew, err := req.CheckRenew(config.Certificates[0].Certificate[0]) if err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Error("libvault: failed to parse existing tls certificate") + return false, err } - ttl := cert.NotAfter.Sub(time.Now()) - if ttl.Seconds() > VAULT_RENEW_PERIOD.Seconds() { + if !renew { // nothing to see here, move along log.WithFields(log.Fields{ "CommonName": req.CommonName, - "ttl": ttl.String(), + "ttl": req.TTL.String(), }).Info("libvault: TLS certificate renewal not required") return false, nil } } - // default the TTL if it wasn't previously set - if req.TTL == 0 { - req.TTL = VAULT_TTL - } - - // issue a new key pair - log.WithFields(log.Fields{ - "CommonName": req.CommonName, - }).Debug("libvault: renewing TLS certificate") - - response := &struct { - Data struct { - Certificate string `json:"certificate"` - IssuingCA string `json:"issuing_ca"` - CAChain []string `json:"ca_chain"` - PrivateKey string `json:"private_key"` - } `json:"data"` - }{} - - if err := vault.POST(t, "/burble.dn42/pki/sites/issue/tls", req, response); err != nil { - log.WithFields(log.Fields{ - "token": t, - "request": req, - "error": err, - }).Error("libvault: vault failed to renew certificate") + // issue a new cert + kc, err := req.Issue(t) + if err != nil { + log.Error("libvault: certificate renewal failed") return false, err } @@ -77,19 +120,19 @@ func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) { config.ServerName = req.CommonName config.RootCAs = x509.NewCertPool() - config.RootCAs.AppendCertsFromPEM([]byte(response.Data.IssuingCA)) - for _, ca := range response.Data.CAChain { + config.RootCAs.AppendCertsFromPEM([]byte(kc.IssuingCA)) + for _, ca := range kc.CAChain { config.RootCAs.AppendCertsFromPEM([]byte(ca)) } cert, err := tls.X509KeyPair( - []byte(response.Data.Certificate), - []byte(response.Data.PrivateKey), + []byte(kc.Certificate), + []byte(kc.PrivateKey), ) if err != nil { log.WithFields(log.Fields{ - "cert": response.Data.Certificate, - "key": response.Data.PrivateKey, + "cert": kc.Certificate, + "key": kc.PrivateKey, "error": err, }).Error("libvault: unable to load x509 cert pair") return false, err @@ -103,5 +146,82 @@ func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) { return true, nil } +////////////////////////////////////////////////////////////////////////// +// auto renew + +func (req *TLSRequest) AutoRenew( + t *Token, + config *tls.Config, + callback func(config *tls.Config), +) { + + log.Info("Starting TLS auto renew") + req.stop = make(chan bool) + req.done.Add(1) + rgen := rand.New(rand.NewSource(time.Now().UnixNano())) + + // every day + ticker := time.NewTicker(24 * time.Hour) + + go func() { + defer req.done.Done() + + for { + for i := 0; i < 3; i++ { + // attempt to renew + updated, err := req.Renew(t, config) + if err != nil { + + // if renew fails then sleep for a while and try again + sleep := time.Duration(rgen.Intn(300) + 600) + + log.WithFields(log.Fields{ + "attempt": i, + "sleep": sleep, + "error": err, + }).Error("libvault: auto renew failed") + + time.Sleep(sleep * time.Second) + + } else { + // no error + + if updated { + if callback != nil { + callback(config) + } + } + break + } + } + + // wait for the timer expiry + select { + case <-req.stop: + return + case <-ticker.C: + // sleep for a random period before going round loop + // to prevent stampeeding herds + + sleep := time.Duration(rgen.Intn(300)) + + log.WithFields(log.Fields{ + "sleep": sleep, + }).Debug("libvault: TLS auto renew") + + time.Sleep(sleep * time.Second) + } + } + }() + +} + +func (req *TLSRequest) Shutdown() { + log.Info("Stopping TLS AutoRenew") + req.stop <- true + req.done.Wait() + log.Info("TLS AutoRenew Stopped") +} + ////////////////////////////////////////////////////////////////////////// // end of file