123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- /*
- *
- * Copyright 2022 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
- // Package v2 provides the S2Av2 transport credentials used by a gRPC
- // application.
- package v2
- import (
- "context"
- "crypto/tls"
- "errors"
- "net"
- "os"
- "time"
- "github.com/golang/protobuf/proto"
- "github.com/google/s2a-go/fallback"
- "github.com/google/s2a-go/internal/handshaker/service"
- "github.com/google/s2a-go/internal/tokenmanager"
- "github.com/google/s2a-go/internal/v2/tlsconfigstore"
- "github.com/google/s2a-go/stream"
- "google.golang.org/grpc"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/grpclog"
- commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
- s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
- )
// Protocol identification and handshake-timeout settings for S2Av2.
const (
	// s2aSecurityProtocol is reported as the SecurityProtocol in the
	// credentials' ProtocolInfo.
	s2aSecurityProtocol = "tls"

	// defaultS2ATimeout is the fallback timeout returned by GetS2ATimeout.
	defaultS2ATimeout = 3 * time.Second

	// s2aTimeoutEnv names the environment variable that sets the timeout
	// enforced on the connection to the S2A service for handshake.
	s2aTimeoutEnv = "S2A_TIMEOUT"
)
- type s2av2TransportCreds struct {
- info *credentials.ProtocolInfo
- isClient bool
- serverName string
- s2av2Address string
- tokenManager *tokenmanager.AccessTokenManager
- // localIdentity should only be used by the client.
- localIdentity *commonpbv1.Identity
- // localIdentities should only be used by the server.
- localIdentities []*commonpbv1.Identity
- verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
- fallbackClientHandshake fallback.ClientHandshake
- getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
- serverAuthorizationPolicy []byte
- }
- // NewClientCreds returns a client-side transport credentials object that uses
- // the S2Av2 to establish a secure connection with a server.
- func NewClientCreds(s2av2Address string, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
- // Create an AccessTokenManager instance to use to authenticate to S2Av2.
- accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
- creds := &s2av2TransportCreds{
- info: &credentials.ProtocolInfo{
- SecurityProtocol: s2aSecurityProtocol,
- },
- isClient: true,
- serverName: "",
- s2av2Address: s2av2Address,
- localIdentity: localIdentity,
- verificationMode: verificationMode,
- fallbackClientHandshake: fallbackClientHandshakeFunc,
- getS2AStream: getS2AStream,
- serverAuthorizationPolicy: serverAuthorizationPolicy,
- }
- if err != nil {
- creds.tokenManager = nil
- } else {
- creds.tokenManager = &accessTokenManager
- }
- if grpclog.V(1) {
- grpclog.Info("Created client S2Av2 transport credentials.")
- }
- return creds, nil
- }
- // NewServerCreds returns a server-side transport credentials object that uses
- // the S2Av2 to establish a secure connection with a client.
- func NewServerCreds(s2av2Address string, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
- // Create an AccessTokenManager instance to use to authenticate to S2Av2.
- accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
- creds := &s2av2TransportCreds{
- info: &credentials.ProtocolInfo{
- SecurityProtocol: s2aSecurityProtocol,
- },
- isClient: false,
- s2av2Address: s2av2Address,
- localIdentities: localIdentities,
- verificationMode: verificationMode,
- getS2AStream: getS2AStream,
- }
- if err != nil {
- creds.tokenManager = nil
- } else {
- creds.tokenManager = &accessTokenManager
- }
- if grpclog.V(1) {
- grpclog.Info("Created server S2Av2 transport credentials.")
- }
- return creds, nil
- }
- // ClientHandshake performs a client-side mTLS handshake using the S2Av2.
- func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- if !c.isClient {
- return nil, nil, errors.New("client handshake called using server transport credentials")
- }
- // Remove the port from serverAuthority.
- serverName := removeServerNamePort(serverAuthority)
- timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
- defer cancel()
- s2AStream, err := createStream(timeoutCtx, c.s2av2Address, c.getS2AStream)
- if err != nil {
- grpclog.Infof("Failed to connect to S2Av2: %v", err)
- if c.fallbackClientHandshake != nil {
- return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
- }
- return nil, nil, err
- }
- defer s2AStream.CloseSend()
- if grpclog.V(1) {
- grpclog.Infof("Connected to S2Av2.")
- }
- var config *tls.Config
- var tokenManager tokenmanager.AccessTokenManager
- if c.tokenManager == nil {
- tokenManager = nil
- } else {
- tokenManager = *c.tokenManager
- }
- if c.serverName == "" {
- config, err = tlsconfigstore.GetTLSConfigurationForClient(serverName, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
- if err != nil {
- grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
- if c.fallbackClientHandshake != nil {
- return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
- }
- return nil, nil, err
- }
- } else {
- config, err = tlsconfigstore.GetTLSConfigurationForClient(c.serverName, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
- if err != nil {
- grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
- if c.fallbackClientHandshake != nil {
- return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
- }
- return nil, nil, err
- }
- }
- if grpclog.V(1) {
- grpclog.Infof("Got client TLS config from S2Av2.")
- }
- creds := credentials.NewTLS(config)
- conn, authInfo, err := creds.ClientHandshake(ctx, serverName, rawConn)
- if err != nil {
- grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
- if c.fallbackClientHandshake != nil {
- return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
- }
- return nil, nil, err
- }
- grpclog.Infof("Successfully done client handshake using S2Av2 to: %s", serverName)
- return conn, authInfo, err
- }
- // ServerHandshake performs a server-side mTLS handshake using the S2Av2.
- func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- if c.isClient {
- return nil, nil, errors.New("server handshake called using client transport credentials")
- }
- ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
- defer cancel()
- s2AStream, err := createStream(ctx, c.s2av2Address, c.getS2AStream)
- if err != nil {
- grpclog.Infof("Failed to connect to S2Av2: %v", err)
- return nil, nil, err
- }
- defer s2AStream.CloseSend()
- if grpclog.V(1) {
- grpclog.Infof("Connected to S2Av2.")
- }
- var tokenManager tokenmanager.AccessTokenManager
- if c.tokenManager == nil {
- tokenManager = nil
- } else {
- tokenManager = *c.tokenManager
- }
- config, err := tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
- if err != nil {
- grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
- return nil, nil, err
- }
- if grpclog.V(1) {
- grpclog.Infof("Got server TLS config from S2Av2.")
- }
- creds := credentials.NewTLS(config)
- return creds.ServerHandshake(rawConn)
- }
- // Info returns protocol info of s2av2TransportCreds.
- func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
- return *c.info
- }
- // Clone makes a deep copy of s2av2TransportCreds.
- func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
- info := *c.info
- serverName := c.serverName
- fallbackClientHandshake := c.fallbackClientHandshake
- s2av2Address := c.s2av2Address
- var tokenManager tokenmanager.AccessTokenManager
- if c.tokenManager == nil {
- tokenManager = nil
- } else {
- tokenManager = *c.tokenManager
- }
- verificationMode := c.verificationMode
- var localIdentity *commonpbv1.Identity
- if c.localIdentity != nil {
- localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
- }
- var localIdentities []*commonpbv1.Identity
- if c.localIdentities != nil {
- localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
- for i, localIdentity := range c.localIdentities {
- localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
- }
- }
- creds := &s2av2TransportCreds{
- info: &info,
- isClient: c.isClient,
- serverName: serverName,
- fallbackClientHandshake: fallbackClientHandshake,
- s2av2Address: s2av2Address,
- localIdentity: localIdentity,
- localIdentities: localIdentities,
- verificationMode: verificationMode,
- }
- if c.tokenManager == nil {
- creds.tokenManager = nil
- } else {
- creds.tokenManager = &tokenManager
- }
- return creds
- }
- // NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
- // a client. The tls.Config MUST only be used to establish a single TLS connection.
- func NewClientTLSConfig(
- ctx context.Context,
- s2av2Address string,
- tokenManager tokenmanager.AccessTokenManager,
- verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
- serverName string,
- serverAuthorizationPolicy []byte) (*tls.Config, error) {
- s2AStream, err := createStream(ctx, s2av2Address, nil)
- if err != nil {
- grpclog.Infof("Failed to connect to S2Av2: %v", err)
- return nil, err
- }
- return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
- }
- // OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
- // info. The ServerName MUST be a hostname.
- func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
- serverName := removeServerNamePort(serverNameOverride)
- c.info.ServerName = serverName
- c.serverName = serverName
- return nil
- }
// removeServerNamePort returns serverName with any trailing ":port" removed.
// Input that net.SplitHostPort rejects (such as a bare hostname without a
// port) is returned unchanged.
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
- type s2AGrpcStream struct {
- stream s2av2pb.S2AService_SetUpSessionClient
- }
- func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
- return x.stream.Send(m)
- }
- func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
- return x.stream.Recv()
- }
- func (x s2AGrpcStream) CloseSend() error {
- return x.stream.CloseSend()
- }
- func createStream(ctx context.Context, s2av2Address string, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
- if getS2AStream != nil {
- return getS2AStream(ctx, s2av2Address)
- }
- // TODO(rmehta19): Consider whether to close the connection to S2Av2.
- conn, err := service.Dial(s2av2Address)
- if err != nil {
- return nil, err
- }
- client := s2av2pb.NewS2AServiceClient(conn)
- gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
- if err != nil {
- return nil, err
- }
- return &s2AGrpcStream{
- stream: gRPCStream,
- }, nil
- }
- // GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
- func GetS2ATimeout() time.Duration {
- timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
- if err != nil {
- return defaultS2ATimeout
- }
- return timeout
- }
|