s2a_fallback.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. /*
  2. *
  3. * Copyright 2023 Google LLC
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * https://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package fallback provides default implementations of fallback options when S2A fails.
  19. package fallback
import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"strings"

	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/grpclog"
)
  28. const (
  29. alpnProtoStrH2 = "h2"
  30. alpnProtoStrHTTP = "http/1.1"
  31. defaultHTTPSPort = "443"
  32. )
  33. // FallbackTLSConfigGRPC is a tls.Config used by the DefaultFallbackClientHandshakeFunc function.
  34. // It supports GRPC use case, thus the alpn is set to 'h2'.
  35. var FallbackTLSConfigGRPC = tls.Config{
  36. MinVersion: tls.VersionTLS13,
  37. ClientSessionCache: nil,
  38. NextProtos: []string{alpnProtoStrH2},
  39. }
  40. // FallbackTLSConfigHTTP is a tls.Config used by the DefaultFallbackDialerAndAddress func.
  41. // It supports the HTTP use case and the alpn is set to both 'http/1.1' and 'h2'.
  42. var FallbackTLSConfigHTTP = tls.Config{
  43. MinVersion: tls.VersionTLS13,
  44. ClientSessionCache: nil,
  45. NextProtos: []string{alpnProtoStrH2, alpnProtoStrHTTP},
  46. }
  47. // ClientHandshake establishes a TLS connection and returns it, plus its auth info.
  48. // Inputs:
  49. //
  50. // targetServer: the server attempted with S2A.
  51. // conn: the tcp connection to the server at address targetServer that was passed into S2A's ClientHandshake func.
  52. // If fallback is successful, the `conn` should be closed.
  53. // err: the error encountered when performing the client-side TLS handshake with S2A.
  54. type ClientHandshake func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error)
  55. // DefaultFallbackClientHandshakeFunc returns a ClientHandshake function,
  56. // which establishes a TLS connection to the provided fallbackAddr, returns the new connection and its auth info.
  57. // Example use:
  58. //
  59. // transportCreds, _ = s2a.NewClientCreds(&s2a.ClientOptions{
  60. // S2AAddress: s2aAddress,
  61. // FallbackOpts: &s2a.FallbackOptions{ // optional
  62. // FallbackClientHandshakeFunc: fallback.DefaultFallbackClientHandshakeFunc(fallbackAddr),
  63. // },
  64. // })
  65. //
  66. // The fallback server's certificate must be verifiable using OS root store.
  67. // The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
  68. // it uses default port 443.
  69. // In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
  70. // and min TLS version is set to 1.3.
  71. func DefaultFallbackClientHandshakeFunc(fallbackAddr string) (ClientHandshake, error) {
  72. var fallbackDialer = tls.Dialer{Config: &FallbackTLSConfigGRPC}
  73. return defaultFallbackClientHandshakeFuncInternal(fallbackAddr, fallbackDialer.DialContext)
  74. }
  75. func defaultFallbackClientHandshakeFuncInternal(fallbackAddr string, dialContextFunc func(context.Context, string, string) (net.Conn, error)) (ClientHandshake, error) {
  76. fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
  77. if err != nil {
  78. if grpclog.V(1) {
  79. grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
  80. }
  81. return nil, err
  82. }
  83. return func(ctx context.Context, targetServer string, conn net.Conn, s2aErr error) (net.Conn, credentials.AuthInfo, error) {
  84. fbConn, fbErr := dialContextFunc(ctx, "tcp", fallbackServerAddr)
  85. if fbErr != nil {
  86. grpclog.Infof("dialing to fallback server %s failed: %v", fallbackServerAddr, fbErr)
  87. return nil, nil, fmt.Errorf("dialing to fallback server %s failed: %v; S2A client handshake with %s error: %w", fallbackServerAddr, fbErr, targetServer, s2aErr)
  88. }
  89. tc, success := fbConn.(*tls.Conn)
  90. if !success {
  91. grpclog.Infof("the connection with fallback server is expected to be tls but isn't")
  92. return nil, nil, fmt.Errorf("the connection with fallback server is expected to be tls but isn't; S2A client handshake with %s error: %w", targetServer, s2aErr)
  93. }
  94. tlsInfo := credentials.TLSInfo{
  95. State: tc.ConnectionState(),
  96. CommonAuthInfo: credentials.CommonAuthInfo{
  97. SecurityLevel: credentials.PrivacyAndIntegrity,
  98. },
  99. }
  100. if grpclog.V(1) {
  101. grpclog.Infof("ConnectionState.NegotiatedProtocol: %v", tc.ConnectionState().NegotiatedProtocol)
  102. grpclog.Infof("ConnectionState.HandshakeComplete: %v", tc.ConnectionState().HandshakeComplete)
  103. grpclog.Infof("ConnectionState.ServerName: %v", tc.ConnectionState().ServerName)
  104. }
  105. conn.Close()
  106. return fbConn, tlsInfo, nil
  107. }, nil
  108. }
  109. // DefaultFallbackDialerAndAddress returns a TLS dialer and the network address to dial.
  110. // Example use:
  111. //
  112. // fallbackDialer, fallbackServerAddr := fallback.DefaultFallbackDialerAndAddress(fallbackAddr)
  113. // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
  114. // S2AAddress: s2aAddress, // required
  115. // FallbackOpts: &s2a.FallbackOptions{
  116. // FallbackDialer: &s2a.FallbackDialer{
  117. // Dialer: fallbackDialer,
  118. // ServerAddr: fallbackServerAddr,
  119. // },
  120. // },
  121. // })
  122. //
  123. // The fallback server's certificate should be verifiable using OS root store.
  124. // The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
  125. // it uses default port 443.
  126. // In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
  127. // and min TLS version is set to 1.3.
  128. func DefaultFallbackDialerAndAddress(fallbackAddr string) (*tls.Dialer, string, error) {
  129. fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
  130. if err != nil {
  131. if grpclog.V(1) {
  132. grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
  133. }
  134. return nil, "", err
  135. }
  136. return &tls.Dialer{Config: &FallbackTLSConfigHTTP}, fallbackServerAddr, nil
  137. }
  138. func processFallbackAddr(fallbackAddr string) (string, error) {
  139. var fallbackServerAddr string
  140. var err error
  141. if fallbackAddr == "" {
  142. return "", fmt.Errorf("empty fallback address")
  143. }
  144. _, _, err = net.SplitHostPort(fallbackAddr)
  145. if err != nil {
  146. // fallbackAddr does not have port suffix
  147. fallbackServerAddr = net.JoinHostPort(fallbackAddr, defaultHTTPSPort)
  148. } else {
  149. // FallbackServerAddr already has port suffix
  150. fallbackServerAddr = fallbackAddr
  151. }
  152. return fallbackServerAddr, nil
  153. }