transport.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package ca
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "crypto/x509/pkix"
  6. "net"
  7. "strings"
  8. "sync"
  9. "github.com/pkg/errors"
  10. "golang.org/x/net/context"
  11. "google.golang.org/grpc/credentials"
  12. )
  // alpnProtoStr is the list of application-level protocols advertised via
  // TLS ALPN for gRPC connections. "h2" is the ALPN identifier for HTTP/2.
  var alpnProtoStr = []string{"h2"}
  // timeoutError is an error value describing a timed-out dial.
  // Its Timeout and Temporary methods both report true.
  // NOTE(review): no use of this type is visible in this file; presumably it is
  // returned by dialing code elsewhere in the package — confirm.
  type timeoutError struct{}

  // Error returns the fixed message for a timed-out dial.
  func (e timeoutError) Error() string {
  	return "mutablecredentials: Dial timed out"
  }

  // Timeout always reports true.
  func (e timeoutError) Timeout() bool {
  	return true
  }

  // Temporary always reports true.
  func (e timeoutError) Temporary() bool {
  	return true
  }
  21. // MutableTLSCreds is the credentials required for authenticating a connection using TLS.
  22. type MutableTLSCreds struct {
  23. // Mutex for the tls config
  24. sync.Mutex
  25. // TLS configuration
  26. config *tls.Config
  27. // TLS Credentials
  28. tlsCreds credentials.TransportCredentials
  29. // store the subject for easy access
  30. subject pkix.Name
  31. }
  32. // Info implements the credentials.TransportCredentials interface
  33. func (c *MutableTLSCreds) Info() credentials.ProtocolInfo {
  34. return credentials.ProtocolInfo{
  35. SecurityProtocol: "tls",
  36. SecurityVersion: "1.2",
  37. }
  38. }
  39. // Clone returns new MutableTLSCreds created from underlying *tls.Config.
  40. // It panics if validation of underlying config fails.
  41. func (c *MutableTLSCreds) Clone() credentials.TransportCredentials {
  42. c.Lock()
  43. newCfg, err := NewMutableTLS(c.config)
  44. if err != nil {
  45. panic("validation error on Clone")
  46. }
  47. c.Unlock()
  48. return newCfg
  49. }
  50. // OverrideServerName overrides *tls.Config.ServerName.
  51. func (c *MutableTLSCreds) OverrideServerName(name string) error {
  52. c.Lock()
  53. c.config.ServerName = name
  54. c.Unlock()
  55. return nil
  56. }
  57. // GetRequestMetadata implements the credentials.TransportCredentials interface
  58. func (c *MutableTLSCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
  59. return nil, nil
  60. }
  61. // RequireTransportSecurity implements the credentials.TransportCredentials interface
  62. func (c *MutableTLSCreds) RequireTransportSecurity() bool {
  63. return true
  64. }
  // ClientHandshake implements the credentials.TransportCredentials interface.
  //
  // It performs a TLS client handshake over rawConn using a private copy of the
  // current TLS config. If that config has no ServerName, the part of addr
  // before the last ':' (or all of addr if there is none) is used as the
  // ServerName for this connection only.
  //
  // The handshake is abandoned when ctx is done. On any failure rawConn is
  // closed and the error is returned. On success the TLS connection and its
  // connection state (as credentials.TLSInfo, matching ServerHandshake) are
  // returned.
  func (c *MutableTLSCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  	// borrow all the code from the original TLS credentials
  	// Take a private copy of the config under the lock. Writing ServerName
  	// into the shared config would pin every later dial to the first address
  	// seen, and would race with in-flight handshakes reading that config.
  	c.Lock()
  	cfg := c.config.Clone()
  	c.Unlock()

  	if cfg.ServerName == "" {
  		host := addr
  		if colonPos := strings.LastIndex(addr, ":"); colonPos != -1 {
  			host = addr[:colonPos]
  		}
  		cfg.ServerName = host
  	}
  	conn := tls.Client(rawConn, cfg)

  	// Run the handshake in a goroutine so ctx cancellation can interrupt the
  	// wait. The channel is buffered so the goroutine can always deliver its
  	// result and exit even after we stop listening; closing rawConn below is
  	// what should unblock a handshake that is still in progress.
  	errChannel := make(chan error, 1)
  	go func() {
  		errChannel <- conn.Handshake()
  	}()

  	var err error
  	select {
  	case err = <-errChannel:
  	case <-ctx.Done():
  		err = ctx.Err()
  	}
  	if err != nil {
  		rawConn.Close()
  		return nil, nil, err
  	}
  	return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
  }
  // ServerHandshake implements the credentials.TransportCredentials interface.
  //
  // It wraps rawConn in a TLS server connection built from the current config
  // and runs the handshake synchronously. On failure rawConn is closed and the
  // handshake error is returned. On success the TLS connection is returned
  // together with its connection state as credentials.TLSInfo.
  func (c *MutableTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  	c.Lock()
  	tlsConn := tls.Server(rawConn, c.config)
  	c.Unlock()

  	handshakeErr := tlsConn.Handshake()
  	if handshakeErr == nil {
  		return tlsConn, credentials.TLSInfo{State: tlsConn.ConnectionState()}, nil
  	}
  	rawConn.Close()
  	return nil, nil, handshakeErr
  }
  // loadNewTLSConfig replaces the currently loaded TLS config with a new one.
  //
  // The subject of newConfig's certificates is validated first with
  // GetAndValidateCertificateSubject. On error, the current config and cached
  // subject are left untouched and the error is returned.
  //
  // Like NewMutableTLS, it sets newConfig.NextProtos to the gRPC ALPN list, so
  // a reloaded config keeps advertising "h2". newConfig is retained, not
  // copied.
  func (c *MutableTLSCreds) loadNewTLSConfig(newConfig *tls.Config) error {
  	newSubject, err := GetAndValidateCertificateSubject(newConfig.Certificates)
  	if err != nil {
  		return err
  	}
  	// Mirror NewMutableTLS; without this a reload could silently drop ALPN "h2".
  	// NOTE(review): this mutates the caller's config, exactly as NewMutableTLS does.
  	newConfig.NextProtos = alpnProtoStr

  	c.Lock()
  	defer c.Unlock()
  	c.subject = newSubject
  	c.config = newConfig
  	return nil
  }
  // Config returns the current underlying TLS config.
  // The pointer itself is returned, not a copy, so changes made through it
  // are visible to these credentials.
  func (c *MutableTLSCreds) Config() *tls.Config {
  	c.Lock()
  	current := c.config
  	c.Unlock()
  	return current
  }
  125. // Role returns the OU for the certificate encapsulated in this TransportCredentials
  126. func (c *MutableTLSCreds) Role() string {
  127. c.Lock()
  128. defer c.Unlock()
  129. return c.subject.OrganizationalUnit[0]
  130. }
  131. // Organization returns the O for the certificate encapsulated in this TransportCredentials
  132. func (c *MutableTLSCreds) Organization() string {
  133. c.Lock()
  134. defer c.Unlock()
  135. return c.subject.Organization[0]
  136. }
  137. // NodeID returns the CN for the certificate encapsulated in this TransportCredentials
  138. func (c *MutableTLSCreds) NodeID() string {
  139. c.Lock()
  140. defer c.Unlock()
  141. return c.subject.CommonName
  142. }
  // NewMutableTLS uses c to construct a mutable TransportCredentials based on TLS.
  //
  // c must hold at least one certificate, and its subject must pass
  // GetAndValidateCertificateSubject; otherwise an error is returned.
  //
  // On success c itself is retained (not copied), and its NextProtos is
  // overwritten with the gRPC ALPN list. The tlsCreds field is built from c via
  // credentials.NewTLS before NextProtos is modified.
  func NewMutableTLS(c *tls.Config) (*MutableTLSCreds, error) {
  	grpcCreds := credentials.NewTLS(c)
  	if len(c.Certificates) == 0 {
  		return nil, errors.New("invalid configuration: needs at least one certificate")
  	}
  	subject, err := GetAndValidateCertificateSubject(c.Certificates)
  	if err != nil {
  		return nil, err
  	}
  	c.NextProtos = alpnProtoStr
  	return &MutableTLSCreds{
  		config:   c,
  		tlsCreds: grpcCreds,
  		subject:  subject,
  	}, nil
  }
  // GetAndValidateCertificateSubject is a helper method to retrieve and validate the subject
  // from the x509 certificate underlying a tls.Certificate.
  //
  // certs are scanned in order. Entries whose leaf (Certificate[0]) is missing
  // or cannot be parsed as x509 are skipped.
  //
  // The first entry whose leaf parses decides the result. Its subject must
  // carry at least one OrganizationalUnit, at least one Organization and a
  // non-empty CommonName. That subject is returned on success; otherwise the
  // matching validation error is returned without examining later entries.
  //
  // If no entry has a parseable leaf (including an empty certs), an error is
  // returned.
  func GetAndValidateCertificateSubject(certs []tls.Certificate) (pkix.Name, error) {
  	for i := range certs {
  		chain := certs[i].Certificate
  		// A tls.Certificate with an empty chain has no leaf to inspect; indexing
  		// chain[0] would panic, so treat it like an unparseable entry.
  		if len(chain) == 0 {
  			continue
  		}
  		leaf, err := x509.ParseCertificate(chain[0])
  		if err != nil {
  			continue
  		}
  		subject := leaf.Subject
  		switch {
  		case len(subject.OrganizationalUnit) == 0:
  			return pkix.Name{}, errors.New("no OU found in certificate subject")
  		case len(subject.Organization) == 0:
  			return pkix.Name{}, errors.New("no organization found in certificate subject")
  		case subject.CommonName == "":
  			return pkix.Name{}, errors.New("no valid subject names found for TLS configuration")
  		}
  		return subject, nil
  	}
  	return pkix.Name{}, errors.New("no valid certificates found for TLS configuration")
  }