first version mini-vault

This commit is contained in:
Simon Marsh 2022-07-25 15:53:20 +01:00
parent fbfae93b5e
commit bba6a7d04f
Signed by: burble
GPG Key ID: 0FCCD13AE1CF7ED8
7 changed files with 497 additions and 46 deletions

3
go.mod
View File

@ -3,6 +3,9 @@ module libvault
go 1.18
require (
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spf13/cobra v1.5.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
)

9
go.sum
View File

@ -1,11 +1,20 @@
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU=
github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

75
mini-vault/mini-vault.go Normal file
View File

@ -0,0 +1,75 @@
//////////////////////////////////////////////////////////////////////////
package main
//////////////////////////////////////////////////////////////////////////
import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"os"
vault "libvault"
)
//////////////////////////////////////////////////////////////////////////
// everything starts here
func main() {
log.SetLevel(log.ErrorLevel)
cmdRoot := &cobra.Command{
Use: "mini-vault",
Short: "Hashicorp Vault helper proglet",
}
cmdRoot.PersistentFlags().StringVarP(&TokenFile, "token", "t", "", "Token file")
// configure subcommands
cmdToken := &cobra.Command{
Use: "token",
Short: "token manipulation",
}
cmdTokenRenew := &cobra.Command{
Use: "renew",
Short: "Renew Token",
Run: CmdTokenRenew,
}
cmdTokenRenew.Flags().StringVarP(&TokenTTL, "ttl", "l", "", "Renewal TTL")
cmdTLS := &cobra.Command{
Use: "tls",
Short: "TLS cert management",
}
cmdTLSRenew := &cobra.Command{
Use: "renew",
Short: "Renew TLS certificate",
Run: CmdTLSRenew,
}
cmdTLSRenew.Flags().StringVarP(&TLSCertPEM, "cert", "c", "", "Path to Certificate PEM")
cmdTLSRenew.MarkFlagRequired("cert")
cmdTLSRenew.Flags().StringVarP(&TLSKeyPEM, "key", "k", "", "Path to Key PEM")
cmdTLSRenew.MarkFlagRequired("key")
cmdTLSRenew.Flags().StringVarP(&TLSCAPEM, "ca", "a", "", "Path to CA PEM")
cmdTLSRenew.MarkFlagRequired("ca")
cmdTLSRenew.Flags().StringVarP(&TLSRequest, "request", "r", "", "Request Parameters")
cmdTLSRenew.MarkFlagRequired("request")
cmdRoot.AddCommand(cmdToken, cmdTLS)
cmdToken.AddCommand(cmdTokenRenew)
cmdTLS.AddCommand(cmdTLSRenew)
// set vault address from environment
va := os.Getenv("VAULT_ADDR")
if va != "" {
vault.VAULT_ADDR = va
}
// do it
cmdRoot.Execute()
}
//////////////////////////////////////////////////////////////////////////
// end of code

151
mini-vault/tls.go Normal file
View File

@ -0,0 +1,151 @@
//////////////////////////////////////////////////////////////////////////
package main
//////////////////////////////////////////////////////////////////////////
import (
// log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"encoding/json"
"encoding/pem"
"fmt"
"os"
vault "libvault"
)
//////////////////////////////////////////////////////////////////////////
var (
TLSCertPEM string
TLSKeyPEM string
TLSCAPEM string
TLSRequest string
)
//////////////////////////////////////////////////////////////////////////
// helper funcs
func loadRequest(filename string) *vault.TLSRequest {
content, err := os.ReadFile(filename)
if err != nil {
fmt.Printf("ERROR: failed to read request file (%s): %s\n",
filename, err)
os.Exit(1)
}
request := &vault.TLSRequest{}
if err := json.Unmarshal(content, request); err != nil {
fmt.Printf("ERROR: failed to parse request file (%s): %s\n",
filename, err)
os.Exit(1)
}
return request
}
func writePEM(filename string, blocks []*pem.Block) {
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
fmt.Printf("ERROR: failed to create PEM file (%s): %s\n",
filename, err)
os.Exit(1)
}
defer file.Close()
for _, block := range blocks {
if err := pem.Encode(file, block); err != nil {
fmt.Printf("ERROR: failed to write PEM block (%s): %s\n",
filename, err)
os.Exit(1)
}
}
}
//////////////////////////////////////////////////////////////////////////
func CmdTLSRenew(cmd *cobra.Command, args []string) {
// load token and TLS request parameters
token := loadToken()
request := loadRequest(TLSRequest)
// load existing cert if it existed
if _, err := os.Stat(TLSCertPEM); err == nil {
fmt.Printf("Loading existing certificate: %s\n", TLSCertPEM)
data, err := os.ReadFile(TLSCertPEM)
if err != nil {
fmt.Printf("ERROR: failed to read existing certificate: %s\n", err)
os.Exit(1)
}
block, _ := pem.Decode(data)
if block == nil || block.Type != "CERTIFICATE" {
fmt.Println("ERROR: failed to parse PEM block")
os.Exit(1)
}
// check if certificate needed renewing
renew, err := request.CheckRenew(block.Bytes)
if err != nil {
fmt.Printf("ERROR: failed to check existing certificate: %s\n", err)
os.Exit(1)
}
if !renew {
// nothing to do
fmt.Println("Renewal not required, no action")
os.Exit(0)
}
}
// issue the cert
kc, err := request.Issue(token)
if err != nil {
fmt.Printf("ERROR: failed to issue TLS cert: %s\n", err)
os.Exit(1)
}
// write out the certs
fmt.Println("Success ! updating certs")
fmt.Printf(" - Certificate: %s\n", TLSCertPEM)
if err := os.WriteFile(
TLSCertPEM,
[]byte(kc.Certificate+"\n"+kc.IssuingCA),
0600,
); err != nil {
fmt.Printf("ERROR: failed to write certificate: %s\n", err)
os.Exit(1)
}
fmt.Printf(" - Private Key: %s\n", TLSKeyPEM)
if err := os.WriteFile(
TLSKeyPEM,
[]byte(kc.PrivateKey+"\n"),
0600,
); err != nil {
fmt.Printf("ERROR: failed to write key: %s\n", err)
os.Exit(1)
}
fmt.Printf(" - CA: %s\n", TLSCAPEM)
if err := os.WriteFile(
TLSCAPEM,
[]byte(kc.IssuingCA+"\n"),
0600,
); err != nil {
fmt.Printf("ERROR: failed to write CA: %s\n", err)
os.Exit(1)
}
os.Exit(0)
}
//////////////////////////////////////////////////////////////////////////
// end of code

95
mini-vault/token.go Normal file
View File

@ -0,0 +1,95 @@
//////////////////////////////////////////////////////////////////////////
package main
//////////////////////////////////////////////////////////////////////////
import (
// log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"fmt"
"os"
"time"
vault "libvault"
)
//////////////////////////////////////////////////////////////////////////
var (
TokenFile string
TokenTTL string
)
//////////////////////////////////////////////////////////////////////////
// helper funcs
func loadToken() *vault.Token {
var token *vault.Token
if os.Getenv("VAULT_TOKEN") == "" {
// if no env set, read token from a file
filename := TokenFile
if filename == "" {
filename = vault.VAULT_TOKEN_FILE
}
fmt.Printf("Reading token from file: %s\n", filename)
var err error
token, err = vault.NewTokenFromFile(filename)
if err != nil {
fmt.Printf("ERROR: Failed to read token from file: %s\n", err)
os.Exit(1)
}
} else {
// obtain token from environment
token = &vault.Token{}
token.Token = os.Getenv("VAULT_TOKEN")
}
return token
}
//////////////////////////////////////////////////////////////////////////
func CmdTokenRenew(cmd *cobra.Command, args []string) {
token := loadToken()
// set the renewal duration
var ttl time.Duration
if TokenTTL == "" {
ttl = vault.VAULT_TTL
} else {
var err error
ttl, err = time.ParseDuration(TokenTTL)
if err != nil {
fmt.Printf("ERROR: failed to parse TTL: %s\n", err)
os.Exit(1)
}
}
fmt.Printf("Renewing token for %s\n", ttl.String())
if err := token.Renew(ttl); err != nil {
fmt.Printf("ERROR: Failed to renew token: %s\n", err)
os.Exit(1)
}
expiry, err := token.Expires()
if err != nil {
fmt.Printf("ERROR: renewed token, but couldn't get new expiry date: %s\n", err)
os.Exit(1)
}
fmt.Printf("New token expiry date: %s\n", expiry.String())
os.Exit(0)
}
//////////////////////////////////////////////////////////////////////////
// end of code

View File

@ -1,6 +1,4 @@
//////////////////////////////////////////////////////////////////////////
// burble.dn42 services
//////////////////////////////////////////////////////////////////////////
package main

208
tls.go
View File

@ -8,68 +8,111 @@ import (
"crypto/tls"
"crypto/x509"
log "github.com/sirupsen/logrus"
"math/rand"
"sync"
"time"
)
//////////////////////////////////////////////////////////////////////////
type TLSRequest struct {
CommonName string `json:"common_name"`
AltNames string `json:"alt_names"`
IPSANs string `json:"ip_sans"`
URISANs string `json:"uri_sans"`
OtherSANs string `json:"other_sans"`
TTL time.Duration `json:"ttl"`
CommonName string `json:"common_name"`
AltNames string `json:"alt_names"`
IPSANs string `json:"ip_sans"`
URISANs string `json:"uri_sans"`
OtherSANs string `json:"other_sans"`
TTL time.Duration `json:"ttl"`
RenewPeriod time.Duration `json:"-"`
stop chan bool `json:"-"`
done sync.WaitGroup `json:"-"`
}
type TLSKeyCert struct {
Certificate string `json:"certificate"`
IssuingCA string `json:"issuing_ca"`
CAChain []string `json:"ca_chain"`
PrivateKey string `json:"private_key"`
}
//////////////////////////////////////////////////////////////////////////
// issue a new certificate based on the request
func (req *TLSRequest) Issue(t *Token) (*TLSKeyCert, error) {
log.WithFields(log.Fields{
"CommonName": req.CommonName,
}).Debug("libvault: Issuing certificate")
// default the TTL if required
if req.TTL == 0 {
req.TTL = VAULT_TTL
}
response := &struct {
Data *TLSKeyCert `json:"data"`
}{Data: &TLSKeyCert{}}
if err := vault.POST(t, "/burble.dn42/pki/sites/issue/tls",
req, response); err != nil {
log.WithFields(log.Fields{
"request": req,
"error": err,
}).Error("libvault: vault failed to issue certificate")
return nil, err
}
return response.Data, nil
}
//////////////////////////////////////////////////////////////////////////
// check if a certificate needs renewing
func (req *TLSRequest) CheckRenew(cdata []byte) (bool, error) {
// default the renew period
if req.RenewPeriod == 0 {
req.RenewPeriod = VAULT_RENEW_PERIOD
}
// parse the certificate
cert, err := x509.ParseCertificate(cdata)
if err != nil {
log.WithFields(log.Fields{
"error": err,
}).Error("libvault: failed to parse tls certificate")
return false, err
}
// and check the ttl
ttl := cert.NotAfter.Sub(time.Now())
return (ttl.Seconds() < req.RenewPeriod.Seconds()), nil
}
//////////////////////////////////////////////////////////////////////////
func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) {
// if there is an existing certificate, check if it needs renewing
if len(config.Certificates) > 0 {
cert, err := x509.ParseCertificate(config.Certificates[0].Certificate[0])
renew, err := req.CheckRenew(config.Certificates[0].Certificate[0])
if err != nil {
log.WithFields(log.Fields{
"error": err,
}).Error("libvault: failed to parse existing tls certificate")
return false, err
}
ttl := cert.NotAfter.Sub(time.Now())
if ttl.Seconds() > VAULT_RENEW_PERIOD.Seconds() {
if !renew {
// nothing to see here, move along
log.WithFields(log.Fields{
"CommonName": req.CommonName,
"ttl": ttl.String(),
"ttl": req.TTL.String(),
}).Info("libvault: TLS certificate renewal not required")
return false, nil
}
}
// default the TTL if it wasn't previously set
if req.TTL == 0 {
req.TTL = VAULT_TTL
}
// issue a new key pair
log.WithFields(log.Fields{
"CommonName": req.CommonName,
}).Debug("libvault: renewing TLS certificate")
response := &struct {
Data struct {
Certificate string `json:"certificate"`
IssuingCA string `json:"issuing_ca"`
CAChain []string `json:"ca_chain"`
PrivateKey string `json:"private_key"`
} `json:"data"`
}{}
if err := vault.POST(t, "/burble.dn42/pki/sites/issue/tls", req, response); err != nil {
log.WithFields(log.Fields{
"token": t,
"request": req,
"error": err,
}).Error("libvault: vault failed to renew certificate")
// issue a new cert
kc, err := req.Issue(t)
if err != nil {
log.Error("libvault: certificate renewal failed")
return false, err
}
@ -77,19 +120,19 @@ func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) {
config.ServerName = req.CommonName
config.RootCAs = x509.NewCertPool()
config.RootCAs.AppendCertsFromPEM([]byte(response.Data.IssuingCA))
for _, ca := range response.Data.CAChain {
config.RootCAs.AppendCertsFromPEM([]byte(kc.IssuingCA))
for _, ca := range kc.CAChain {
config.RootCAs.AppendCertsFromPEM([]byte(ca))
}
cert, err := tls.X509KeyPair(
[]byte(response.Data.Certificate),
[]byte(response.Data.PrivateKey),
[]byte(kc.Certificate),
[]byte(kc.PrivateKey),
)
if err != nil {
log.WithFields(log.Fields{
"cert": response.Data.Certificate,
"key": response.Data.PrivateKey,
"cert": kc.Certificate,
"key": kc.PrivateKey,
"error": err,
}).Error("libvault: unable to load x509 cert pair")
return false, err
@ -103,5 +146,82 @@ func (req *TLSRequest) Renew(t *Token, config *tls.Config) (bool, error) {
return true, nil
}
//////////////////////////////////////////////////////////////////////////
// auto renew
func (req *TLSRequest) AutoRenew(
t *Token,
config *tls.Config,
callback func(config *tls.Config),
) {
log.Info("Starting TLS auto renew")
req.stop = make(chan bool)
req.done.Add(1)
rgen := rand.New(rand.NewSource(time.Now().UnixNano()))
// every day
ticker := time.NewTicker(24 * time.Hour)
go func() {
defer req.done.Done()
for {
for i := 0; i < 3; i++ {
// attempt to renew
updated, err := req.Renew(t, config)
if err != nil {
// if renew fails then sleep for a while and try again
sleep := time.Duration(rgen.Intn(300) + 600)
log.WithFields(log.Fields{
"attempt": i,
"sleep": sleep,
"error": err,
}).Error("libvault: auto renew failed")
time.Sleep(sleep * time.Second)
} else {
// no error
if updated {
if callback != nil {
callback(config)
}
}
break
}
}
// wait for the timer expiry
select {
case <-req.stop:
return
case <-ticker.C:
// sleep for a random period before going round loop
// to prevent stampeeding herds
sleep := time.Duration(rgen.Intn(300))
log.WithFields(log.Fields{
"sleep": sleep,
}).Debug("libvault: TLS auto renew")
time.Sleep(sleep * time.Second)
}
}
}()
}
func (req *TLSRequest) Shutdown() {
log.Info("Stopping TLS AutoRenew")
req.stop <- true
req.done.Wait()
log.Info("TLS AutoRenew Stopped")
}
//////////////////////////////////////////////////////////////////////////
// end of file