|
@@ -20,6 +20,8 @@ import (
|
|
|
"crypto/x509/pkix"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "io/fs"
|
|
|
+ "math/rand"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"sync"
|
|
@@ -34,6 +36,17 @@ const (
|
|
|
DefaultTLSKeyPaidID = "default"
|
|
|
)
|
|
|
|
|
|
+var (
|
|
|
+ certAutoReload bool
|
|
|
+)
|
|
|
+
|
|
|
+// SetCertAutoReloadMode sets if the certificate must be monitored for changes and
|
|
|
+// automatically reloaded
|
|
|
+func SetCertAutoReloadMode(val bool) {
|
|
|
+ certAutoReload = val
|
|
|
+ logger.Debug(logSender, "", "is certificate monitoring enabled? %t", certAutoReload)
|
|
|
+}
|
|
|
+
|
|
|
// TLSKeyPair defines the paths and the unique identifier for a TLS key pair
|
|
|
type TLSKeyPair struct {
|
|
|
Cert string
|
|
@@ -49,7 +62,9 @@ type CertManager struct {
|
|
|
sync.RWMutex
|
|
|
caCertificates []string
|
|
|
caRevocationLists []string
|
|
|
+ monitorList []string
|
|
|
certs map[string]*tls.Certificate
|
|
|
+ certsInfo map[string]fs.FileInfo
|
|
|
rootCAs *x509.CertPool
|
|
|
crls []*pkix.CertificateList
|
|
|
}
|
|
@@ -77,15 +92,18 @@ func (m *CertManager) loadCertificates() error {
|
|
|
}
|
|
|
newCert, err := tls.LoadX509KeyPair(keyPair.Cert, keyPair.Key)
|
|
|
if err != nil {
|
|
|
- logger.Warn(m.logSender, "", "unable to load X509 key pair, cert file %#v key file %#v error: %v",
|
|
|
+ logger.Warn(m.logSender, "", "unable to load X509 key pair, cert file %q key file %q error: %v",
|
|
|
keyPair.Cert, keyPair.Key, err)
|
|
|
return err
|
|
|
}
|
|
|
if _, ok := certs[keyPair.ID]; ok {
|
|
|
- return fmt.Errorf("TLS certificate with id %#v is duplicated", keyPair.ID)
|
|
|
+ return fmt.Errorf("TLS certificate with id %q is duplicated", keyPair.ID)
|
|
|
}
|
|
|
- logger.Debug(m.logSender, "", "TLS certificate %#v successfully loaded, id %v", keyPair.Cert, keyPair.ID)
|
|
|
+ logger.Debug(m.logSender, "", "TLS certificate %q successfully loaded, id %v", keyPair.Cert, keyPair.ID)
|
|
|
certs[keyPair.ID] = &newCert
|
|
|
+ if !util.Contains(m.monitorList, keyPair.Cert) {
|
|
|
+ m.monitorList = append(m.monitorList, keyPair.Cert)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
m.Lock()
|
|
@@ -116,7 +134,7 @@ func (m *CertManager) IsRevoked(crt *x509.Certificate, caCrt *x509.Certificate)
|
|
|
defer m.RUnlock()
|
|
|
|
|
|
if crt == nil || caCrt == nil {
|
|
|
- logger.Warn(m.logSender, "", "unable to verify crt %v ca crt %v", crt, caCrt)
|
|
|
+ logger.Warn(m.logSender, "", "unable to verify crt %v, ca crt %v", crt, caCrt)
|
|
|
return len(m.crls) > 0
|
|
|
}
|
|
|
|
|
@@ -143,24 +161,27 @@ func (m *CertManager) LoadCRLs() error {
|
|
|
|
|
|
for _, revocationList := range m.caRevocationLists {
|
|
|
if !util.IsFileInputValid(revocationList) {
|
|
|
- return fmt.Errorf("invalid root CA revocation list %#v", revocationList)
|
|
|
+ return fmt.Errorf("invalid root CA revocation list %q", revocationList)
|
|
|
}
|
|
|
if revocationList != "" && !filepath.IsAbs(revocationList) {
|
|
|
revocationList = filepath.Join(m.configDir, revocationList)
|
|
|
}
|
|
|
crlBytes, err := os.ReadFile(revocationList)
|
|
|
if err != nil {
|
|
|
- logger.Warn(m.logSender, "unable to read revocation list %#v", revocationList)
|
|
|
+ logger.Warn(m.logSender, "", "unable to read revocation list %q", revocationList)
|
|
|
return err
|
|
|
}
|
|
|
crl, err := x509.ParseCRL(crlBytes)
|
|
|
if err != nil {
|
|
|
- logger.Warn(m.logSender, "unable to parse revocation list %#v", revocationList)
|
|
|
+ logger.Warn(m.logSender, "", "unable to parse revocation list %q", revocationList)
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- logger.Debug(m.logSender, "", "CRL %#v successfully loaded", revocationList)
|
|
|
+ logger.Debug(m.logSender, "", "CRL %q successfully loaded", revocationList)
|
|
|
crls = append(crls, crl)
|
|
|
+ if !util.Contains(m.monitorList, revocationList) {
|
|
|
+ m.monitorList = append(m.monitorList, revocationList)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
m.Lock()
|
|
@@ -190,7 +211,7 @@ func (m *CertManager) LoadRootCAs() error {
|
|
|
|
|
|
for _, rootCA := range m.caCertificates {
|
|
|
if !util.IsFileInputValid(rootCA) {
|
|
|
- return fmt.Errorf("invalid root CA certificate %#v", rootCA)
|
|
|
+ return fmt.Errorf("invalid root CA certificate %q", rootCA)
|
|
|
}
|
|
|
if rootCA != "" && !filepath.IsAbs(rootCA) {
|
|
|
rootCA = filepath.Join(m.configDir, rootCA)
|
|
@@ -200,9 +221,9 @@ func (m *CertManager) LoadRootCAs() error {
|
|
|
return err
|
|
|
}
|
|
|
if rootCAs.AppendCertsFromPEM(crt) {
|
|
|
- logger.Debug(m.logSender, "", "TLS certificate authority %#v successfully loaded", rootCA)
|
|
|
+ logger.Debug(m.logSender, "", "TLS certificate authority %q successfully loaded", rootCA)
|
|
|
} else {
|
|
|
- err := fmt.Errorf("unable to load TLS certificate authority %#v", rootCA)
|
|
|
+ err := fmt.Errorf("unable to load TLS certificate authority %q", rootCA)
|
|
|
logger.Warn(m.logSender, "", "%v", err)
|
|
|
return err
|
|
|
}
|
|
@@ -227,11 +248,45 @@ func (m *CertManager) SetCARevocationLists(caRevocationLists []string) {
|
|
|
m.caRevocationLists = util.RemoveDuplicates(caRevocationLists, true)
|
|
|
}
|
|
|
|
|
|
+func (m *CertManager) monitor() {
|
|
|
+ certsInfo := make(map[string]fs.FileInfo)
|
|
|
+
|
|
|
+ for _, crt := range m.monitorList {
|
|
|
+ info, err := os.Stat(crt)
|
|
|
+ if err != nil {
|
|
|
+ logger.Warn(m.logSender, "", "unable to stat certificate to monitor %q: %v", crt, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ certsInfo[crt] = info
|
|
|
+ }
|
|
|
+
|
|
|
+ m.Lock()
|
|
|
+
|
|
|
+ isChanged := false
|
|
|
+ for k, oldInfo := range m.certsInfo {
|
|
|
+ newInfo, ok := certsInfo[k]
|
|
|
+ if ok {
|
|
|
+ if newInfo.Size() != oldInfo.Size() || newInfo.ModTime() != oldInfo.ModTime() {
|
|
|
+ logger.Debug(m.logSender, "", "change detected for certificate %q, reload required", k)
|
|
|
+ isChanged = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ m.certsInfo = certsInfo
|
|
|
+
|
|
|
+ m.Unlock()
|
|
|
+
|
|
|
+ if isChanged {
|
|
|
+ m.Reload() //nolint:errcheck
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// NewCertManager creates a new certificate manager
|
|
|
func NewCertManager(keyPairs []TLSKeyPair, configDir, logSender string) (*CertManager, error) {
|
|
|
manager := &CertManager{
|
|
|
keyPairs: keyPairs,
|
|
|
certs: make(map[string]*tls.Certificate),
|
|
|
+ certsInfo: make(map[string]fs.FileInfo),
|
|
|
configDir: configDir,
|
|
|
logSender: logSender,
|
|
|
}
|
|
@@ -239,5 +294,11 @@ func NewCertManager(keyPairs []TLSKeyPair, configDir, logSender string) (*CertMa
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+ if certAutoReload {
|
|
|
+ randSecs := rand.Intn(59)
|
|
|
+ manager.monitor()
|
|
|
+ _, err := eventScheduler.AddFunc(fmt.Sprintf("@every 8h0m%ds", randSecs), manager.monitor)
|
|
|
+ util.PanicOnError(err)
|
|
|
+ }
|
|
|
return manager, nil
|
|
|
}
|