s2a.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. /*
  2. *
  3. * Copyright 2021 Google LLC
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * https://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package s2a provides the S2A transport credentials used by a gRPC
  19. // application.
  20. package s2a
  21. import (
  22. "context"
  23. "crypto/tls"
  24. "errors"
  25. "fmt"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/golang/protobuf/proto"
  30. "github.com/google/s2a-go/fallback"
  31. "github.com/google/s2a-go/internal/handshaker"
  32. "github.com/google/s2a-go/internal/handshaker/service"
  33. "github.com/google/s2a-go/internal/tokenmanager"
  34. "github.com/google/s2a-go/internal/v2"
  35. "google.golang.org/grpc/credentials"
  36. "google.golang.org/grpc/grpclog"
  37. commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
  38. s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  39. )
  const (
	// s2aSecurityProtocol is the SecurityProtocol reported in the
	// credentials.ProtocolInfo of the legacy (S2Av1) transport credentials.
	s2aSecurityProtocol = "tls"
	// defaultTimeout is the lifetime of the context used to create the
	// legacy server-side handshaker.
	defaultTimeout = 30 * time.Second
)
  45. // s2aTransportCreds are the transport credentials required for establishing
  46. // a secure connection using the S2A. They implement the
  47. // credentials.TransportCredentials interface.
  48. type s2aTransportCreds struct {
  49. info *credentials.ProtocolInfo
  50. minTLSVersion commonpb.TLSVersion
  51. maxTLSVersion commonpb.TLSVersion
  52. // tlsCiphersuites contains the ciphersuites used in the S2A connection.
  53. // Note that these are currently unconfigurable.
  54. tlsCiphersuites []commonpb.Ciphersuite
  55. // localIdentity should only be used by the client.
  56. localIdentity *commonpb.Identity
  57. // localIdentities should only be used by the server.
  58. localIdentities []*commonpb.Identity
  59. // targetIdentities should only be used by the client.
  60. targetIdentities []*commonpb.Identity
  61. isClient bool
  62. s2aAddr string
  63. ensureProcessSessionTickets *sync.WaitGroup
  64. }
  65. // NewClientCreds returns a client-side transport credentials object that uses
  66. // the S2A to establish a secure connection with a server.
  67. func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
  68. if opts == nil {
  69. return nil, errors.New("nil client options")
  70. }
  71. var targetIdentities []*commonpb.Identity
  72. for _, targetIdentity := range opts.TargetIdentities {
  73. protoTargetIdentity, err := toProtoIdentity(targetIdentity)
  74. if err != nil {
  75. return nil, err
  76. }
  77. targetIdentities = append(targetIdentities, protoTargetIdentity)
  78. }
  79. localIdentity, err := toProtoIdentity(opts.LocalIdentity)
  80. if err != nil {
  81. return nil, err
  82. }
  83. if opts.EnableLegacyMode {
  84. return &s2aTransportCreds{
  85. info: &credentials.ProtocolInfo{
  86. SecurityProtocol: s2aSecurityProtocol,
  87. },
  88. minTLSVersion: commonpb.TLSVersion_TLS1_3,
  89. maxTLSVersion: commonpb.TLSVersion_TLS1_3,
  90. tlsCiphersuites: []commonpb.Ciphersuite{
  91. commonpb.Ciphersuite_AES_128_GCM_SHA256,
  92. commonpb.Ciphersuite_AES_256_GCM_SHA384,
  93. commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
  94. },
  95. localIdentity: localIdentity,
  96. targetIdentities: targetIdentities,
  97. isClient: true,
  98. s2aAddr: opts.S2AAddress,
  99. ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
  100. }, nil
  101. }
  102. verificationMode := getVerificationMode(opts.VerificationMode)
  103. var fallbackFunc fallback.ClientHandshake
  104. if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
  105. fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
  106. }
  107. return v2.NewClientCreds(opts.S2AAddress, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
  108. }
  109. // NewServerCreds returns a server-side transport credentials object that uses
  110. // the S2A to establish a secure connection with a client.
  111. func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
  112. if opts == nil {
  113. return nil, errors.New("nil server options")
  114. }
  115. var localIdentities []*commonpb.Identity
  116. for _, localIdentity := range opts.LocalIdentities {
  117. protoLocalIdentity, err := toProtoIdentity(localIdentity)
  118. if err != nil {
  119. return nil, err
  120. }
  121. localIdentities = append(localIdentities, protoLocalIdentity)
  122. }
  123. if opts.EnableLegacyMode {
  124. return &s2aTransportCreds{
  125. info: &credentials.ProtocolInfo{
  126. SecurityProtocol: s2aSecurityProtocol,
  127. },
  128. minTLSVersion: commonpb.TLSVersion_TLS1_3,
  129. maxTLSVersion: commonpb.TLSVersion_TLS1_3,
  130. tlsCiphersuites: []commonpb.Ciphersuite{
  131. commonpb.Ciphersuite_AES_128_GCM_SHA256,
  132. commonpb.Ciphersuite_AES_256_GCM_SHA384,
  133. commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
  134. },
  135. localIdentities: localIdentities,
  136. isClient: false,
  137. s2aAddr: opts.S2AAddress,
  138. }, nil
  139. }
  140. verificationMode := getVerificationMode(opts.VerificationMode)
  141. return v2.NewServerCreds(opts.S2AAddress, localIdentities, verificationMode, opts.getS2AStream)
  142. }
  143. // ClientHandshake initiates a client-side TLS handshake using the S2A.
  144. func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  145. if !c.isClient {
  146. return nil, nil, errors.New("client handshake called using server transport credentials")
  147. }
  148. // Connect to the S2A.
  149. hsConn, err := service.Dial(c.s2aAddr)
  150. if err != nil {
  151. grpclog.Infof("Failed to connect to S2A: %v", err)
  152. return nil, nil, err
  153. }
  154. var cancel context.CancelFunc
  155. ctx, cancel = context.WithCancel(ctx)
  156. defer cancel()
  157. opts := &handshaker.ClientHandshakerOptions{
  158. MinTLSVersion: c.minTLSVersion,
  159. MaxTLSVersion: c.maxTLSVersion,
  160. TLSCiphersuites: c.tlsCiphersuites,
  161. TargetIdentities: c.targetIdentities,
  162. LocalIdentity: c.localIdentity,
  163. TargetName: serverAuthority,
  164. EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
  165. }
  166. chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
  167. if err != nil {
  168. grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
  169. return nil, nil, err
  170. }
  171. defer func() {
  172. if err != nil {
  173. if closeErr := chs.Close(); closeErr != nil {
  174. grpclog.Infof("Close failed unexpectedly: %v", err)
  175. err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
  176. }
  177. }
  178. }()
  179. secConn, authInfo, err := chs.ClientHandshake(context.Background())
  180. if err != nil {
  181. grpclog.Infof("Handshake failed: %v", err)
  182. return nil, nil, err
  183. }
  184. return secConn, authInfo, nil
  185. }
  186. // ServerHandshake initiates a server-side TLS handshake using the S2A.
  187. func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  188. if c.isClient {
  189. return nil, nil, errors.New("server handshake called using client transport credentials")
  190. }
  191. // Connect to the S2A.
  192. hsConn, err := service.Dial(c.s2aAddr)
  193. if err != nil {
  194. grpclog.Infof("Failed to connect to S2A: %v", err)
  195. return nil, nil, err
  196. }
  197. ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
  198. defer cancel()
  199. opts := &handshaker.ServerHandshakerOptions{
  200. MinTLSVersion: c.minTLSVersion,
  201. MaxTLSVersion: c.maxTLSVersion,
  202. TLSCiphersuites: c.tlsCiphersuites,
  203. LocalIdentities: c.localIdentities,
  204. }
  205. shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
  206. if err != nil {
  207. grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
  208. return nil, nil, err
  209. }
  210. defer func() {
  211. if err != nil {
  212. if closeErr := shs.Close(); closeErr != nil {
  213. grpclog.Infof("Close failed unexpectedly: %v", err)
  214. err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
  215. }
  216. }
  217. }()
  218. secConn, authInfo, err := shs.ServerHandshake(context.Background())
  219. if err != nil {
  220. grpclog.Infof("Handshake failed: %v", err)
  221. return nil, nil, err
  222. }
  223. return secConn, authInfo, nil
  224. }
  225. func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
  226. return *c.info
  227. }
  228. func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
  229. info := *c.info
  230. var localIdentity *commonpb.Identity
  231. if c.localIdentity != nil {
  232. localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
  233. }
  234. var localIdentities []*commonpb.Identity
  235. if c.localIdentities != nil {
  236. localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
  237. for i, localIdentity := range c.localIdentities {
  238. localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
  239. }
  240. }
  241. var targetIdentities []*commonpb.Identity
  242. if c.targetIdentities != nil {
  243. targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
  244. for i, targetIdentity := range c.targetIdentities {
  245. targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
  246. }
  247. }
  248. return &s2aTransportCreds{
  249. info: &info,
  250. minTLSVersion: c.minTLSVersion,
  251. maxTLSVersion: c.maxTLSVersion,
  252. tlsCiphersuites: c.tlsCiphersuites,
  253. localIdentity: localIdentity,
  254. localIdentities: localIdentities,
  255. targetIdentities: targetIdentities,
  256. isClient: c.isClient,
  257. s2aAddr: c.s2aAddr,
  258. }
  259. }
  260. func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
  261. c.info.ServerName = serverNameOverride
  262. return nil
  263. }
  // TLSClientConfigOptions specifies parameters for creating a client TLS config.
//
// Example:
//
//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
//		ServerName: "example.com",
//	})
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the
	// hostname found in the server's certificate.
	ServerName string
}
  272. // TLSClientConfigFactory defines the interface for a client TLS config factory.
  273. type TLSClientConfigFactory interface {
  274. Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
  275. }
  276. // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
  277. func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
  278. if opts == nil {
  279. return nil, fmt.Errorf("opts must be non-nil")
  280. }
  281. if opts.EnableLegacyMode {
  282. return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
  283. }
  284. tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  285. if err != nil {
  286. // The only possible error is: access token not set in the environment,
  287. // which is okay in environments other than serverless.
  288. grpclog.Infof("Access token manager not initialized: %v", err)
  289. return &s2aTLSClientConfigFactory{
  290. s2av2Address: opts.S2AAddress,
  291. tokenManager: nil,
  292. verificationMode: getVerificationMode(opts.VerificationMode),
  293. serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
  294. }, nil
  295. }
  296. return &s2aTLSClientConfigFactory{
  297. s2av2Address: opts.S2AAddress,
  298. tokenManager: tokenManager,
  299. verificationMode: getVerificationMode(opts.VerificationMode),
  300. serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
  301. }, nil
  302. }
  303. type s2aTLSClientConfigFactory struct {
  304. s2av2Address string
  305. tokenManager tokenmanager.AccessTokenManager
  306. verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
  307. serverAuthorizationPolicy []byte
  308. }
  309. func (f *s2aTLSClientConfigFactory) Build(
  310. ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
  311. serverName := ""
  312. if opts != nil && opts.ServerName != "" {
  313. serverName = opts.ServerName
  314. }
  315. return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
  316. }
  317. func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
  318. switch verificationMode {
  319. case ConnectToGoogle:
  320. return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
  321. case Spiffe:
  322. return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
  323. default:
  324. return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
  325. }
  326. }
  327. // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
  328. // Example use with http.RoundTripper:
  329. //
  330. // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
  331. // S2AAddress: s2aAddress, // required
  332. // })
  333. // transport := http.DefaultTransport
  334. // transport.DialTLSContext = dialTLSContext
  335. func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
  336. return func(ctx context.Context, network, addr string) (net.Conn, error) {
  337. fallback := func(err error) (net.Conn, error) {
  338. if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
  339. opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
  340. fbDialer := opts.FallbackOpts.FallbackDialer
  341. grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
  342. fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
  343. if fbErr != nil {
  344. return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
  345. }
  346. return fbConn, nil
  347. }
  348. return nil, err
  349. }
  350. factory, err := NewTLSClientConfigFactory(opts)
  351. if err != nil {
  352. grpclog.Infof("error creating S2A client config factory: %v", err)
  353. return fallback(err)
  354. }
  355. serverName, _, err := net.SplitHostPort(addr)
  356. if err != nil {
  357. serverName = addr
  358. }
  359. timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
  360. defer cancel()
  361. s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
  362. ServerName: serverName,
  363. })
  364. if err != nil {
  365. grpclog.Infof("error building S2A TLS config: %v", err)
  366. return fallback(err)
  367. }
  368. s2aDialer := &tls.Dialer{
  369. Config: s2aTLSConfig,
  370. }
  371. c, err := s2aDialer.DialContext(ctx, network, addr)
  372. if err != nil {
  373. grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
  374. return fallback(err)
  375. }
  376. grpclog.Infof("success dialing MTLS to %s with S2A", addr)
  377. return c, nil
  378. }
  379. }