123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412 |
- /*
- *
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
- // Package s2a provides the S2A transport credentials used by a gRPC
- // application.
- package s2a
- import (
- "context"
- "crypto/tls"
- "errors"
- "fmt"
- "net"
- "sync"
- "time"
- "github.com/golang/protobuf/proto"
- "github.com/google/s2a-go/fallback"
- "github.com/google/s2a-go/internal/handshaker"
- "github.com/google/s2a-go/internal/handshaker/service"
- "github.com/google/s2a-go/internal/tokenmanager"
- "github.com/google/s2a-go/internal/v2"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/grpclog"
- commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
- s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
- )
// Constants shared by the legacy S2A transport credentials.
const (
	// s2aSecurityProtocol is the security protocol name reported through the
	// credentials' ProtocolInfo.
	s2aSecurityProtocol = "tls"
	// defaultTimeout is the default timeout applied to a server handshake.
	defaultTimeout = 30 * time.Second
)
- // s2aTransportCreds are the transport credentials required for establishing
- // a secure connection using the S2A. They implement the
- // credentials.TransportCredentials interface.
- type s2aTransportCreds struct {
- info *credentials.ProtocolInfo
- minTLSVersion commonpb.TLSVersion
- maxTLSVersion commonpb.TLSVersion
- // tlsCiphersuites contains the ciphersuites used in the S2A connection.
- // Note that these are currently unconfigurable.
- tlsCiphersuites []commonpb.Ciphersuite
- // localIdentity should only be used by the client.
- localIdentity *commonpb.Identity
- // localIdentities should only be used by the server.
- localIdentities []*commonpb.Identity
- // targetIdentities should only be used by the client.
- targetIdentities []*commonpb.Identity
- isClient bool
- s2aAddr string
- ensureProcessSessionTickets *sync.WaitGroup
- }
- // NewClientCreds returns a client-side transport credentials object that uses
- // the S2A to establish a secure connection with a server.
- func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
- if opts == nil {
- return nil, errors.New("nil client options")
- }
- var targetIdentities []*commonpb.Identity
- for _, targetIdentity := range opts.TargetIdentities {
- protoTargetIdentity, err := toProtoIdentity(targetIdentity)
- if err != nil {
- return nil, err
- }
- targetIdentities = append(targetIdentities, protoTargetIdentity)
- }
- localIdentity, err := toProtoIdentity(opts.LocalIdentity)
- if err != nil {
- return nil, err
- }
- if opts.EnableLegacyMode {
- return &s2aTransportCreds{
- info: &credentials.ProtocolInfo{
- SecurityProtocol: s2aSecurityProtocol,
- },
- minTLSVersion: commonpb.TLSVersion_TLS1_3,
- maxTLSVersion: commonpb.TLSVersion_TLS1_3,
- tlsCiphersuites: []commonpb.Ciphersuite{
- commonpb.Ciphersuite_AES_128_GCM_SHA256,
- commonpb.Ciphersuite_AES_256_GCM_SHA384,
- commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
- },
- localIdentity: localIdentity,
- targetIdentities: targetIdentities,
- isClient: true,
- s2aAddr: opts.S2AAddress,
- ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
- }, nil
- }
- verificationMode := getVerificationMode(opts.VerificationMode)
- var fallbackFunc fallback.ClientHandshake
- if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
- fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
- }
- return v2.NewClientCreds(opts.S2AAddress, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
- }
- // NewServerCreds returns a server-side transport credentials object that uses
- // the S2A to establish a secure connection with a client.
- func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
- if opts == nil {
- return nil, errors.New("nil server options")
- }
- var localIdentities []*commonpb.Identity
- for _, localIdentity := range opts.LocalIdentities {
- protoLocalIdentity, err := toProtoIdentity(localIdentity)
- if err != nil {
- return nil, err
- }
- localIdentities = append(localIdentities, protoLocalIdentity)
- }
- if opts.EnableLegacyMode {
- return &s2aTransportCreds{
- info: &credentials.ProtocolInfo{
- SecurityProtocol: s2aSecurityProtocol,
- },
- minTLSVersion: commonpb.TLSVersion_TLS1_3,
- maxTLSVersion: commonpb.TLSVersion_TLS1_3,
- tlsCiphersuites: []commonpb.Ciphersuite{
- commonpb.Ciphersuite_AES_128_GCM_SHA256,
- commonpb.Ciphersuite_AES_256_GCM_SHA384,
- commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
- },
- localIdentities: localIdentities,
- isClient: false,
- s2aAddr: opts.S2AAddress,
- }, nil
- }
- verificationMode := getVerificationMode(opts.VerificationMode)
- return v2.NewServerCreds(opts.S2AAddress, localIdentities, verificationMode, opts.getS2AStream)
- }
- // ClientHandshake initiates a client-side TLS handshake using the S2A.
- func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- if !c.isClient {
- return nil, nil, errors.New("client handshake called using server transport credentials")
- }
- // Connect to the S2A.
- hsConn, err := service.Dial(c.s2aAddr)
- if err != nil {
- grpclog.Infof("Failed to connect to S2A: %v", err)
- return nil, nil, err
- }
- var cancel context.CancelFunc
- ctx, cancel = context.WithCancel(ctx)
- defer cancel()
- opts := &handshaker.ClientHandshakerOptions{
- MinTLSVersion: c.minTLSVersion,
- MaxTLSVersion: c.maxTLSVersion,
- TLSCiphersuites: c.tlsCiphersuites,
- TargetIdentities: c.targetIdentities,
- LocalIdentity: c.localIdentity,
- TargetName: serverAuthority,
- EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
- }
- chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
- if err != nil {
- grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
- return nil, nil, err
- }
- defer func() {
- if err != nil {
- if closeErr := chs.Close(); closeErr != nil {
- grpclog.Infof("Close failed unexpectedly: %v", err)
- err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
- }
- }
- }()
- secConn, authInfo, err := chs.ClientHandshake(context.Background())
- if err != nil {
- grpclog.Infof("Handshake failed: %v", err)
- return nil, nil, err
- }
- return secConn, authInfo, nil
- }
- // ServerHandshake initiates a server-side TLS handshake using the S2A.
- func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- if c.isClient {
- return nil, nil, errors.New("server handshake called using client transport credentials")
- }
- // Connect to the S2A.
- hsConn, err := service.Dial(c.s2aAddr)
- if err != nil {
- grpclog.Infof("Failed to connect to S2A: %v", err)
- return nil, nil, err
- }
- ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
- defer cancel()
- opts := &handshaker.ServerHandshakerOptions{
- MinTLSVersion: c.minTLSVersion,
- MaxTLSVersion: c.maxTLSVersion,
- TLSCiphersuites: c.tlsCiphersuites,
- LocalIdentities: c.localIdentities,
- }
- shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
- if err != nil {
- grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
- return nil, nil, err
- }
- defer func() {
- if err != nil {
- if closeErr := shs.Close(); closeErr != nil {
- grpclog.Infof("Close failed unexpectedly: %v", err)
- err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
- }
- }
- }()
- secConn, authInfo, err := shs.ServerHandshake(context.Background())
- if err != nil {
- grpclog.Infof("Handshake failed: %v", err)
- return nil, nil, err
- }
- return secConn, authInfo, nil
- }
- func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
- return *c.info
- }
- func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
- info := *c.info
- var localIdentity *commonpb.Identity
- if c.localIdentity != nil {
- localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
- }
- var localIdentities []*commonpb.Identity
- if c.localIdentities != nil {
- localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
- for i, localIdentity := range c.localIdentities {
- localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
- }
- }
- var targetIdentities []*commonpb.Identity
- if c.targetIdentities != nil {
- targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
- for i, targetIdentity := range c.targetIdentities {
- targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
- }
- }
- return &s2aTransportCreds{
- info: &info,
- minTLSVersion: c.minTLSVersion,
- maxTLSVersion: c.maxTLSVersion,
- tlsCiphersuites: c.tlsCiphersuites,
- localIdentity: localIdentity,
- localIdentities: localIdentities,
- targetIdentities: targetIdentities,
- isClient: c.isClient,
- s2aAddr: c.s2aAddr,
- }
- }
- func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
- c.info.ServerName = serverNameOverride
- return nil
- }
// TLSClientConfigOptions carries the parameters used when creating a client
// TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by S2A: it is the name expected when verifying
	// the hostname found in the server's certificate. Example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
- // TLSClientConfigFactory defines the interface for a client TLS config factory.
- type TLSClientConfigFactory interface {
- Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
- }
- // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
- func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
- if opts == nil {
- return nil, fmt.Errorf("opts must be non-nil")
- }
- if opts.EnableLegacyMode {
- return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
- }
- tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
- if err != nil {
- // The only possible error is: access token not set in the environment,
- // which is okay in environments other than serverless.
- grpclog.Infof("Access token manager not initialized: %v", err)
- return &s2aTLSClientConfigFactory{
- s2av2Address: opts.S2AAddress,
- tokenManager: nil,
- verificationMode: getVerificationMode(opts.VerificationMode),
- serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
- }, nil
- }
- return &s2aTLSClientConfigFactory{
- s2av2Address: opts.S2AAddress,
- tokenManager: tokenManager,
- verificationMode: getVerificationMode(opts.VerificationMode),
- serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
- }, nil
- }
- type s2aTLSClientConfigFactory struct {
- s2av2Address string
- tokenManager tokenmanager.AccessTokenManager
- verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
- serverAuthorizationPolicy []byte
- }
- func (f *s2aTLSClientConfigFactory) Build(
- ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
- serverName := ""
- if opts != nil && opts.ServerName != "" {
- serverName = opts.ServerName
- }
- return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
- }
- func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
- switch verificationMode {
- case ConnectToGoogle:
- return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
- case Spiffe:
- return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
- default:
- return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
- }
- }
- // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
- // Example use with http.RoundTripper:
- //
- // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
- // S2AAddress: s2aAddress, // required
- // })
- // transport := http.DefaultTransport
- // transport.DialTLSContext = dialTLSContext
- func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
- return func(ctx context.Context, network, addr string) (net.Conn, error) {
- fallback := func(err error) (net.Conn, error) {
- if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
- opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
- fbDialer := opts.FallbackOpts.FallbackDialer
- grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
- fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
- if fbErr != nil {
- return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
- }
- return fbConn, nil
- }
- return nil, err
- }
- factory, err := NewTLSClientConfigFactory(opts)
- if err != nil {
- grpclog.Infof("error creating S2A client config factory: %v", err)
- return fallback(err)
- }
- serverName, _, err := net.SplitHostPort(addr)
- if err != nil {
- serverName = addr
- }
- timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
- defer cancel()
- s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
- ServerName: serverName,
- })
- if err != nil {
- grpclog.Infof("error building S2A TLS config: %v", err)
- return fallback(err)
- }
- s2aDialer := &tls.Dialer{
- Config: s2aTLSConfig,
- }
- c, err := s2aDialer.DialContext(ctx, network, addr)
- if err != nil {
- grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
- return fallback(err)
- }
- grpclog.Infof("success dialing MTLS to %s with S2A", addr)
- return c, nil
- }
- }
|