tls_auth.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package v1
  import (
  	"bytes"
  	"crypto"
  	"crypto/x509"
  	"fmt"
  	"io/ioutil"
  	"net/http"
  	"net/url"
  	"sync"
  	"time"

  	"github.com/gin-gonic/gin"
  	"github.com/pkg/errors"
  	log "github.com/sirupsen/logrus"
  	"golang.org/x/crypto/ocsp"
  )
  16. type TLSAuth struct {
  17. AllowedOUs []string
  18. CrlPath string
  19. revokationCache map[string]cacheEntry
  20. cacheExpiration time.Duration
  21. logger *log.Entry
  22. }
  // cacheEntry is one memoized revocation-check result.
  type cacheEntry struct {
  	// revoked is the boolean result of the check.
  	revoked bool
  	// err is the error the check returned, replayed on cache hits.
  	err error
  	// timestamp records when the result was stored (UTC); entries older
  	// than the owning TLSAuth's cacheExpiration are discarded.
  	timestamp time.Time
  }
  28. func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
  29. req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256})
  30. if err != nil {
  31. ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
  32. return nil, err
  33. }
  34. httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
  35. if err != nil {
  36. ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
  37. return nil, err
  38. }
  39. ocspURL, err := url.Parse(server)
  40. if err != nil {
  41. ta.logger.Error("TLSAuth: cannot parse OCSP URL")
  42. return nil, err
  43. }
  44. httpRequest.Header.Add("Content-Type", "application/ocsp-request")
  45. httpRequest.Header.Add("Accept", "application/ocsp-response")
  46. httpRequest.Header.Add("host", ocspURL.Host)
  47. httpClient := &http.Client{}
  48. httpResponse, err := httpClient.Do(httpRequest)
  49. if err != nil {
  50. ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
  51. return nil, err
  52. }
  53. defer httpResponse.Body.Close()
  54. output, err := ioutil.ReadAll(httpResponse.Body)
  55. if err != nil {
  56. ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
  57. return nil, err
  58. }
  59. ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
  60. return ocspResponse, err
  61. }
  62. func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
  63. now := time.Now().UTC()
  64. if cert.NotAfter.UTC().Before(now) {
  65. ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
  66. return true
  67. }
  68. if cert.NotBefore.UTC().After(now) {
  69. ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
  70. return true
  71. }
  72. return false
  73. }
  74. func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  75. if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) {
  76. ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
  77. return false, nil
  78. }
  79. for _, server := range cert.OCSPServer {
  80. ocspResponse, err := ta.ocspQuery(server, cert, issuer)
  81. if err != nil {
  82. ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
  83. continue
  84. }
  85. switch ocspResponse.Status {
  86. case ocsp.Good:
  87. return false, nil
  88. case ocsp.Revoked:
  89. return true, fmt.Errorf("client certificate is revoked by server %s", server)
  90. case ocsp.Unknown:
  91. log.Debugf("unknow OCSP status for server %s", server)
  92. continue
  93. }
  94. }
  95. log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
  96. return true, nil
  97. }
  98. func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
  99. if ta.CrlPath == "" {
  100. ta.logger.Warn("no crl_path, skipping CRL check")
  101. return false, nil
  102. }
  103. crlContent, err := ioutil.ReadFile(ta.CrlPath)
  104. if err != nil {
  105. ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
  106. return false, nil
  107. }
  108. crl, err := x509.ParseCRL(crlContent)
  109. if err != nil {
  110. ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
  111. return false, nil
  112. }
  113. if crl.HasExpired(time.Now().UTC()) {
  114. ta.logger.Warn("CRL has expired, will still validate the cert against it.")
  115. }
  116. for _, revoked := range crl.TBSCertList.RevokedCertificates {
  117. if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
  118. return true, fmt.Errorf("client certificate is revoked by CRL")
  119. }
  120. }
  121. return false, nil
  122. }
  123. func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  124. sn := cert.SerialNumber.String()
  125. if cacheValue, ok := ta.revokationCache[sn]; ok {
  126. if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration {
  127. ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err)
  128. return cacheValue.revoked, cacheValue.err
  129. } else {
  130. ta.logger.Debugf("TLSAuth: cached value expired, removing from cache")
  131. delete(ta.revokationCache, sn)
  132. }
  133. } else {
  134. ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
  135. }
  136. revoked, err := ta.isOCSPRevoked(cert, issuer)
  137. if err != nil {
  138. ta.revokationCache[sn] = cacheEntry{
  139. revoked: revoked,
  140. err: err,
  141. timestamp: time.Now().UTC(),
  142. }
  143. return true, err
  144. }
  145. if revoked {
  146. ta.revokationCache[sn] = cacheEntry{
  147. revoked: revoked,
  148. err: err,
  149. timestamp: time.Now().UTC(),
  150. }
  151. return true, nil
  152. }
  153. revoked, err = ta.isCRLRevoked(cert)
  154. ta.revokationCache[sn] = cacheEntry{
  155. revoked: revoked,
  156. err: err,
  157. timestamp: time.Now().UTC(),
  158. }
  159. return revoked, err
  160. }
  161. func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
  162. if ta.isExpired(cert) {
  163. return true, nil
  164. }
  165. revoked, err := ta.isRevoked(cert, issuer)
  166. if err != nil {
  167. //Fail securely, if we can't check the revokation status, let's consider the cert invalid
  168. //We may change this in the future based on users feedback, but this seems the most sensible thing to do
  169. return true, errors.Wrap(err, "could not check for client certification revokation status")
  170. }
  171. return revoked, nil
  172. }
  173. func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
  174. for _, ou := range allowedOus {
  175. //disallow empty ou
  176. if ou == "" {
  177. return fmt.Errorf("empty ou isn't allowed")
  178. }
  179. //drop & warn on duplicate ou
  180. ok := true
  181. for _, validOu := range ta.AllowedOUs {
  182. if validOu == ou {
  183. ta.logger.Warningf("dropping duplicate ou %s", ou)
  184. ok = false
  185. }
  186. }
  187. if ok {
  188. ta.AllowedOUs = append(ta.AllowedOUs, ou)
  189. }
  190. }
  191. return nil
  192. }
  193. func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
  194. //Checks cert validity, Returns true + CN if client cert matches requested OU
  195. var clientCert *x509.Certificate
  196. if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
  197. //do not error if it's not TLS or there are no peer certs
  198. return false, "", nil
  199. }
  200. if len(c.Request.TLS.VerifiedChains) > 0 {
  201. validOU := false
  202. clientCert = c.Request.TLS.VerifiedChains[0][0]
  203. for _, ou := range clientCert.Subject.OrganizationalUnit {
  204. for _, allowedOu := range ta.AllowedOUs {
  205. if allowedOu == ou {
  206. validOU = true
  207. break
  208. }
  209. }
  210. }
  211. if !validOU {
  212. return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
  213. clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
  214. }
  215. revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
  216. if err != nil {
  217. ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
  218. return false, "", errors.Wrap(err, "could not check for client certification revokation status")
  219. }
  220. if revoked {
  221. return false, "", fmt.Errorf("client certificate is revoked")
  222. }
  223. ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
  224. return true, clientCert.Subject.CommonName, nil
  225. }
  226. return false, "", fmt.Errorf("no verified cert in request")
  227. }
  228. func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) {
  229. ta := &TLSAuth{
  230. revokationCache: map[string]cacheEntry{},
  231. cacheExpiration: cacheExpiration,
  232. CrlPath: crlPath,
  233. logger: logger,
  234. }
  235. err := ta.SetAllowedOu(allowedOus)
  236. if err != nil {
  237. return nil, err
  238. }
  239. return ta, nil
  240. }