tls_auth.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package v1
import (
	"bytes"
	"crypto"
	"crypto/x509"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ocsp"
)
// TLSAuth authenticates clients by their TLS client certificate. A
// certificate is accepted when it carries one of AllowedOUs, is within its
// validity period, and is not reported revoked by OCSP or by the CRL file at
// CrlPath.
type TLSAuth struct {
	// AllowedOUs lists the accepted organizational units; a certificate
	// matches if any one of its subject OUs is in this list.
	AllowedOUs []string

	// CrlPath is the CRL file consulted by isCRLRevoked; an empty value
	// skips the CRL check.
	CrlPath string

	// revokationCache holds the outcome of recent revocation checks, keyed
	// per certificate by isRevoked.
	revokationCache map[string]cacheEntry

	// cacheExpiration is how long a revokationCache entry is reused before
	// the certificate is checked again.
	cacheExpiration time.Duration

	// logger is the log entry TLSAuth methods write their diagnostics to.
	logger *log.Entry
}
// cacheEntry is one cached revocation-check result for a certificate.
type cacheEntry struct {
	// revoked and err are the result of the revocation check; isRevoked
	// returns them unchanged on a cache hit.
	revoked bool
	err     error

	// timestamp is when the entry was stored; an entry older than
	// TLSAuth.cacheExpiration is discarded when it is next looked up.
	timestamp time.Time
}
  29. func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
  30. req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256})
  31. if err != nil {
  32. ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
  33. return nil, err
  34. }
  35. httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
  36. if err != nil {
  37. ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
  38. return nil, err
  39. }
  40. ocspURL, err := url.Parse(server)
  41. if err != nil {
  42. ta.logger.Error("TLSAuth: cannot parse OCSP URL")
  43. return nil, err
  44. }
  45. httpRequest.Header.Add("Content-Type", "application/ocsp-request")
  46. httpRequest.Header.Add("Accept", "application/ocsp-response")
  47. httpRequest.Header.Add("host", ocspURL.Host)
  48. httpClient := &http.Client{}
  49. httpResponse, err := httpClient.Do(httpRequest)
  50. if err != nil {
  51. ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
  52. return nil, err
  53. }
  54. defer httpResponse.Body.Close()
  55. output, err := io.ReadAll(httpResponse.Body)
  56. if err != nil {
  57. ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
  58. return nil, err
  59. }
  60. ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
  61. return ocspResponse, err
  62. }
  63. func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
  64. now := time.Now().UTC()
  65. if cert.NotAfter.UTC().Before(now) {
  66. ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
  67. return true
  68. }
  69. if cert.NotBefore.UTC().After(now) {
  70. ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
  71. return true
  72. }
  73. return false
  74. }
  75. func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  76. if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) {
  77. ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
  78. return false, nil
  79. }
  80. for _, server := range cert.OCSPServer {
  81. ocspResponse, err := ta.ocspQuery(server, cert, issuer)
  82. if err != nil {
  83. ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
  84. continue
  85. }
  86. switch ocspResponse.Status {
  87. case ocsp.Good:
  88. return false, nil
  89. case ocsp.Revoked:
  90. return true, fmt.Errorf("client certificate is revoked by server %s", server)
  91. case ocsp.Unknown:
  92. log.Debugf("unknow OCSP status for server %s", server)
  93. continue
  94. }
  95. }
  96. log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
  97. return true, nil
  98. }
  99. func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
  100. if ta.CrlPath == "" {
  101. ta.logger.Warn("no crl_path, skipping CRL check")
  102. return false, nil
  103. }
  104. crlContent, err := os.ReadFile(ta.CrlPath)
  105. if err != nil {
  106. ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
  107. return false, nil
  108. }
  109. crl, err := x509.ParseCRL(crlContent)
  110. if err != nil {
  111. ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
  112. return false, nil
  113. }
  114. if crl.HasExpired(time.Now().UTC()) {
  115. ta.logger.Warn("CRL has expired, will still validate the cert against it.")
  116. }
  117. for _, revoked := range crl.TBSCertList.RevokedCertificates {
  118. if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
  119. return true, fmt.Errorf("client certificate is revoked by CRL")
  120. }
  121. }
  122. return false, nil
  123. }
  124. func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  125. sn := cert.SerialNumber.String()
  126. if cacheValue, ok := ta.revokationCache[sn]; ok {
  127. if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration {
  128. ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err)
  129. return cacheValue.revoked, cacheValue.err
  130. } else {
  131. ta.logger.Debugf("TLSAuth: cached value expired, removing from cache")
  132. delete(ta.revokationCache, sn)
  133. }
  134. } else {
  135. ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
  136. }
  137. revoked, err := ta.isOCSPRevoked(cert, issuer)
  138. if err != nil {
  139. ta.revokationCache[sn] = cacheEntry{
  140. revoked: revoked,
  141. err: err,
  142. timestamp: time.Now().UTC(),
  143. }
  144. return true, err
  145. }
  146. if revoked {
  147. ta.revokationCache[sn] = cacheEntry{
  148. revoked: revoked,
  149. err: err,
  150. timestamp: time.Now().UTC(),
  151. }
  152. return true, nil
  153. }
  154. revoked, err = ta.isCRLRevoked(cert)
  155. ta.revokationCache[sn] = cacheEntry{
  156. revoked: revoked,
  157. err: err,
  158. timestamp: time.Now().UTC(),
  159. }
  160. return revoked, err
  161. }
  162. func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  163. if ta.isExpired(cert) {
  164. return true, nil
  165. }
  166. revoked, err := ta.isRevoked(cert, issuer)
  167. if err != nil {
  168. //Fail securely, if we can't check the revokation status, let's consider the cert invalid
  169. //We may change this in the future based on users feedback, but this seems the most sensible thing to do
  170. return true, errors.Wrap(err, "could not check for client certification revokation status")
  171. }
  172. return revoked, nil
  173. }
  174. func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
  175. for _, ou := range allowedOus {
  176. //disallow empty ou
  177. if ou == "" {
  178. return fmt.Errorf("empty ou isn't allowed")
  179. }
  180. //drop & warn on duplicate ou
  181. ok := true
  182. for _, validOu := range ta.AllowedOUs {
  183. if validOu == ou {
  184. ta.logger.Warningf("dropping duplicate ou %s", ou)
  185. ok = false
  186. }
  187. }
  188. if ok {
  189. ta.AllowedOUs = append(ta.AllowedOUs, ou)
  190. }
  191. }
  192. return nil
  193. }
  194. func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
  195. //Checks cert validity, Returns true + CN if client cert matches requested OU
  196. var clientCert *x509.Certificate
  197. if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
  198. //do not error if it's not TLS or there are no peer certs
  199. return false, "", nil
  200. }
  201. if len(c.Request.TLS.VerifiedChains) > 0 {
  202. validOU := false
  203. clientCert = c.Request.TLS.VerifiedChains[0][0]
  204. for _, ou := range clientCert.Subject.OrganizationalUnit {
  205. for _, allowedOu := range ta.AllowedOUs {
  206. if allowedOu == ou {
  207. validOU = true
  208. break
  209. }
  210. }
  211. }
  212. if !validOU {
  213. return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
  214. clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
  215. }
  216. revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
  217. if err != nil {
  218. ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
  219. return false, "", errors.Wrap(err, "could not check for client certification revokation status")
  220. }
  221. if revoked {
  222. return false, "", fmt.Errorf("client certificate is revoked")
  223. }
  224. ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
  225. return true, clientCert.Subject.CommonName, nil
  226. }
  227. return false, "", fmt.Errorf("no verified cert in request")
  228. }
  229. func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) {
  230. ta := &TLSAuth{
  231. revokationCache: map[string]cacheEntry{},
  232. cacheExpiration: cacheExpiration,
  233. CrlPath: crlPath,
  234. logger: logger,
  235. }
  236. err := ta.SetAllowedOu(allowedOus)
  237. if err != nil {
  238. return nil, err
  239. }
  240. return ta, nil
  241. }