handshaker.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. /*
  2. *
  3. * Copyright 2021 Google LLC
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * https://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package handshaker communicates with the S2A handshaker service.
  19. package handshaker
  20. import (
  21. "context"
  22. "errors"
  23. "fmt"
  24. "io"
  25. "net"
  26. "sync"
  27. "github.com/google/s2a-go/internal/authinfo"
  28. commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
  29. s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
  30. "github.com/google/s2a-go/internal/record"
  31. "github.com/google/s2a-go/internal/tokenmanager"
  32. grpc "google.golang.org/grpc"
  33. "google.golang.org/grpc/codes"
  34. "google.golang.org/grpc/credentials"
  35. "google.golang.org/grpc/grpclog"
  36. )
var (
	// appProtocol contains the application protocol accepted by the handshaker.
	// It is the only entry sent in the ApplicationProtocols field of the client
	// and server start requests.
	appProtocol = "grpc"
	// frameLimit is the maximum size of a frame in bytes. It is also the size
	// of the buffers used to read handshake bytes from the peer connection.
	frameLimit = 1024 * 64
	// errPeerNotResponding is the error returned when the peer doesn't respond,
	// i.e. when there was nothing to write to the peer and nothing was read
	// from it.
	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")
)
// Handshaker defines a handshaker interface. In this file it is implemented by
// s2aHandshaker, which is what NewClientHandshaker and NewServerHandshaker
// return.
type Handshaker interface {
	// ClientHandshake starts and completes a TLS handshake from the client side,
	// and returns a secure connection along with additional auth information.
	ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
	// ServerHandshake starts and completes a TLS handshake from the server side,
	// and returns a secure connection along with additional auth information.
	ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
	// Close terminates the Handshaker. It should be called when the handshake
	// is complete.
	Close() error
}
// ClientHandshakerOptions contains the options needed to configure the S2A
// handshaker service on the client-side. The fields are copied into the client
// start request sent by ClientHandshake.
type ClientHandshakerOptions struct {
	// MinTLSVersion specifies the min TLS version supported by the client.
	MinTLSVersion commonpb.TLSVersion
	// MaxTLSVersion specifies the max TLS version supported by the client.
	MaxTLSVersion commonpb.TLSVersion
	// TLSCiphersuites is the ordered list of ciphersuites supported by the
	// client.
	TLSCiphersuites []commonpb.Ciphersuite
	// TargetIdentities contains a list of allowed server identities. One of the
	// target identities should match the peer identity in the handshake
	// result; otherwise, the handshake fails.
	TargetIdentities []*commonpb.Identity
	// LocalIdentity is the local identity of the client application. If none is
	// provided, then the S2A will choose the default identity.
	LocalIdentity *commonpb.Identity
	// TargetName is the allowed server name, which may be used for server
	// authorization check by the S2A if it is provided. ClientHandshake treats
	// it as an authority and sends only the host part when a port is present.
	TargetName string
	// EnsureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// It is handed to the record connection created at the end of the
	// handshake.
	EnsureProcessSessionTickets *sync.WaitGroup
}
// ServerHandshakerOptions contains the options needed to configure the S2A
// handshaker service on the server-side. The fields are copied into the server
// start request sent by ServerHandshake.
type ServerHandshakerOptions struct {
	// MinTLSVersion specifies the min TLS version supported by the server.
	MinTLSVersion commonpb.TLSVersion
	// MaxTLSVersion specifies the max TLS version supported by the server.
	MaxTLSVersion commonpb.TLSVersion
	// TLSCiphersuites is the ordered list of ciphersuites supported by the
	// server.
	TLSCiphersuites []commonpb.Ciphersuite
	// LocalIdentities is the list of local identities that may be assumed by
	// the server. If no local identity is specified, then the S2A chooses a
	// default local identity.
	LocalIdentities []*commonpb.Identity
}
// s2aHandshaker performs a TLS handshake using the S2A handshaker service. It
// relays handshake bytes between the peer connection and the S2A stream.
type s2aHandshaker struct {
	// stream is used to communicate with the S2A handshaker service.
	stream s2apb.S2AService_SetUpSessionClient
	// conn is the connection to the peer.
	conn net.Conn
	// clientOpts should be non-nil iff the handshaker is client-side.
	clientOpts *ClientHandshakerOptions
	// serverOpts should be non-nil iff the handshaker is server-side.
	serverOpts *ServerHandshakerOptions
	// isClient determines if the handshaker is client or server side.
	isClient bool
	// hsAddr stores the address of the S2A handshaker service.
	hsAddr string
	// tokenManager manages access tokens for authenticating to S2A. It may be
	// nil, in which case no authentication mechanisms are attached to requests.
	tokenManager tokenmanager.AccessTokenManager
	// localIdentities is the set of local identities for whom the
	// tokenManager should fetch a token when preparing a request to be
	// sent to S2A. It is replaced by the identity selected by the S2A once a
	// response carrying a local identity is received during the handshake.
	localIdentities []*commonpb.Identity
}
  117. // NewClientHandshaker creates an s2aHandshaker instance that performs a
  118. // client-side TLS handshake using the S2A handshaker service.
  119. func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
  120. stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
  121. if err != nil {
  122. return nil, err
  123. }
  124. tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  125. if err != nil {
  126. grpclog.Infof("failed to create single token access token manager: %v", err)
  127. }
  128. return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
  129. }
  130. func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
  131. var localIdentities []*commonpb.Identity
  132. if opts != nil {
  133. localIdentities = []*commonpb.Identity{opts.LocalIdentity}
  134. }
  135. return &s2aHandshaker{
  136. stream: stream,
  137. conn: c,
  138. clientOpts: opts,
  139. isClient: true,
  140. hsAddr: hsAddr,
  141. tokenManager: tokenManager,
  142. localIdentities: localIdentities,
  143. }
  144. }
  145. // NewServerHandshaker creates an s2aHandshaker instance that performs a
  146. // server-side TLS handshake using the S2A handshaker service.
  147. func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
  148. stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
  149. if err != nil {
  150. return nil, err
  151. }
  152. tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  153. if err != nil {
  154. grpclog.Infof("failed to create single token access token manager: %v", err)
  155. }
  156. return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
  157. }
  158. func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
  159. var localIdentities []*commonpb.Identity
  160. if opts != nil {
  161. localIdentities = opts.LocalIdentities
  162. }
  163. return &s2aHandshaker{
  164. stream: stream,
  165. conn: c,
  166. serverOpts: opts,
  167. isClient: false,
  168. hsAddr: hsAddr,
  169. tokenManager: tokenManager,
  170. localIdentities: localIdentities,
  171. }
  172. }
  173. // ClientHandshake performs a client-side TLS handshake using the S2A handshaker
  174. // service. When complete, returns a TLS connection.
  175. func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
  176. if !h.isClient {
  177. return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
  178. }
  179. // Extract the hostname from the target name. The target name is assumed to be an authority.
  180. hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
  181. if err != nil {
  182. // If the target name had no host port or could not be parsed, use it as is.
  183. hostname = h.clientOpts.TargetName
  184. }
  185. // Prepare a client start message to send to the S2A handshaker service.
  186. req := &s2apb.SessionReq{
  187. ReqOneof: &s2apb.SessionReq_ClientStart{
  188. ClientStart: &s2apb.ClientSessionStartReq{
  189. ApplicationProtocols: []string{appProtocol},
  190. MinTlsVersion: h.clientOpts.MinTLSVersion,
  191. MaxTlsVersion: h.clientOpts.MaxTLSVersion,
  192. TlsCiphersuites: h.clientOpts.TLSCiphersuites,
  193. TargetIdentities: h.clientOpts.TargetIdentities,
  194. LocalIdentity: h.clientOpts.LocalIdentity,
  195. TargetName: hostname,
  196. },
  197. },
  198. AuthMechanisms: h.getAuthMechanisms(),
  199. }
  200. conn, result, err := h.setUpSession(req)
  201. if err != nil {
  202. return nil, nil, err
  203. }
  204. authInfo, err := authinfo.NewS2AAuthInfo(result)
  205. if err != nil {
  206. return nil, nil, err
  207. }
  208. return conn, authInfo, nil
  209. }
  210. // ServerHandshake performs a server-side TLS handshake using the S2A handshaker
  211. // service. When complete, returns a TLS connection.
  212. func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
  213. if h.isClient {
  214. return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
  215. }
  216. p := make([]byte, frameLimit)
  217. n, err := h.conn.Read(p)
  218. if err != nil {
  219. return nil, nil, err
  220. }
  221. // Prepare a server start message to send to the S2A handshaker service.
  222. req := &s2apb.SessionReq{
  223. ReqOneof: &s2apb.SessionReq_ServerStart{
  224. ServerStart: &s2apb.ServerSessionStartReq{
  225. ApplicationProtocols: []string{appProtocol},
  226. MinTlsVersion: h.serverOpts.MinTLSVersion,
  227. MaxTlsVersion: h.serverOpts.MaxTLSVersion,
  228. TlsCiphersuites: h.serverOpts.TLSCiphersuites,
  229. LocalIdentities: h.serverOpts.LocalIdentities,
  230. InBytes: p[:n],
  231. },
  232. },
  233. AuthMechanisms: h.getAuthMechanisms(),
  234. }
  235. conn, result, err := h.setUpSession(req)
  236. if err != nil {
  237. return nil, nil, err
  238. }
  239. authInfo, err := authinfo.NewS2AAuthInfo(result)
  240. if err != nil {
  241. return nil, nil, err
  242. }
  243. return conn, authInfo, nil
  244. }
  245. // setUpSession proxies messages between the peer and the S2A handshaker
  246. // service.
  247. func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
  248. resp, err := h.accessHandshakerService(req)
  249. if err != nil {
  250. return nil, nil, err
  251. }
  252. // Check if the returned status is an error.
  253. if resp.GetStatus() != nil {
  254. if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
  255. return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
  256. }
  257. }
  258. // Calculate the extra unread bytes from the Session. Attempting to consume
  259. // more than the bytes sent will throw an error.
  260. var extra []byte
  261. if req.GetServerStart() != nil {
  262. if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
  263. return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
  264. }
  265. extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
  266. }
  267. result, extra, err := h.processUntilDone(resp, extra)
  268. if err != nil {
  269. return nil, nil, err
  270. }
  271. if result.GetLocalIdentity() == nil {
  272. return nil, nil, errors.New("local identity must be populated in session result")
  273. }
  274. // Create a new TLS record protocol using the Session Result.
  275. newConn, err := record.NewConn(&record.ConnParameters{
  276. NetConn: h.conn,
  277. Ciphersuite: result.GetState().GetTlsCiphersuite(),
  278. TLSVersion: result.GetState().GetTlsVersion(),
  279. InTrafficSecret: result.GetState().GetInKey(),
  280. OutTrafficSecret: result.GetState().GetOutKey(),
  281. UnusedBuf: extra,
  282. InSequence: result.GetState().GetInSequence(),
  283. OutSequence: result.GetState().GetOutSequence(),
  284. HSAddr: h.hsAddr,
  285. ConnectionID: result.GetState().GetConnectionId(),
  286. LocalIdentity: result.GetLocalIdentity(),
  287. EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
  288. })
  289. if err != nil {
  290. return nil, nil, err
  291. }
  292. return newConn, result, nil
  293. }
  294. func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
  295. if h.clientOpts == nil {
  296. return nil
  297. }
  298. return h.clientOpts.EnsureProcessSessionTickets
  299. }
  300. // accessHandshakerService sends the session request to the S2A handshaker
  301. // service and returns the session response.
  302. func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
  303. if err := h.stream.Send(req); err != nil {
  304. return nil, err
  305. }
  306. resp, err := h.stream.Recv()
  307. if err != nil {
  308. return nil, err
  309. }
  310. return resp, nil
  311. }
  312. // processUntilDone continues proxying messages between the peer and the S2A
  313. // handshaker service until the handshaker service returns the SessionResult at
  314. // the end of the handshake or an error occurs.
  315. func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
  316. for {
  317. if len(resp.OutFrames) > 0 {
  318. if _, err := h.conn.Write(resp.OutFrames); err != nil {
  319. return nil, nil, err
  320. }
  321. }
  322. if resp.Result != nil {
  323. return resp.Result, unusedBytes, nil
  324. }
  325. buf := make([]byte, frameLimit)
  326. n, err := h.conn.Read(buf)
  327. if err != nil && err != io.EOF {
  328. return nil, nil, err
  329. }
  330. // If there is nothing to send to the handshaker service and nothing is
  331. // received from the peer, then we are stuck. This covers the case when
  332. // the peer is not responding. Note that handshaker service connection
  333. // issues are caught in accessHandshakerService before we even get
  334. // here.
  335. if len(resp.OutFrames) == 0 && n == 0 {
  336. return nil, nil, errPeerNotResponding
  337. }
  338. // Append extra bytes from the previous interaction with the handshaker
  339. // service with the current buffer read from conn.
  340. p := append(unusedBytes, buf[:n]...)
  341. // From here on, p and unusedBytes point to the same slice.
  342. resp, err = h.accessHandshakerService(&s2apb.SessionReq{
  343. ReqOneof: &s2apb.SessionReq_Next{
  344. Next: &s2apb.SessionNextReq{
  345. InBytes: p,
  346. },
  347. },
  348. AuthMechanisms: h.getAuthMechanisms(),
  349. })
  350. if err != nil {
  351. return nil, nil, err
  352. }
  353. // Cache the local identity returned by S2A, if it is populated. This
  354. // overwrites any existing local identities. This is done because, once the
  355. // S2A has selected a local identity, then only that local identity should
  356. // be asserted in future requests until the end of the current handshake.
  357. if resp.GetLocalIdentity() != nil {
  358. h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
  359. }
  360. // Set unusedBytes based on the handshaker service response.
  361. if resp.GetBytesConsumed() > uint32(len(p)) {
  362. return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
  363. }
  364. unusedBytes = p[resp.GetBytesConsumed():]
  365. }
  366. }
  367. // Close shuts down the handshaker and the stream to the S2A handshaker service
  368. // when the handshake is complete. It should be called when the caller obtains
  369. // the secure connection at the end of the handshake.
  370. func (h *s2aHandshaker) Close() error {
  371. return h.stream.CloseSend()
  372. }
  373. func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
  374. if h.tokenManager == nil {
  375. return nil
  376. }
  377. // First handle the special case when no local identities have been provided
  378. // by the application. In this case, an AuthenticationMechanism with no local
  379. // identity will be sent.
  380. if len(h.localIdentities) == 0 {
  381. token, err := h.tokenManager.DefaultToken()
  382. if err != nil {
  383. grpclog.Infof("unable to get token for empty local identity: %v", err)
  384. return nil
  385. }
  386. return []*s2apb.AuthenticationMechanism{
  387. {
  388. MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  389. Token: token,
  390. },
  391. },
  392. }
  393. }
  394. // Next, handle the case where the application (or the S2A) has provided
  395. // one or more local identities.
  396. var authMechanisms []*s2apb.AuthenticationMechanism
  397. for _, localIdentity := range h.localIdentities {
  398. token, err := h.tokenManager.Token(localIdentity)
  399. if err != nil {
  400. grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
  401. continue
  402. }
  403. authMechanism := &s2apb.AuthenticationMechanism{
  404. Identity: localIdentity,
  405. MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  406. Token: token,
  407. },
  408. }
  409. authMechanisms = append(authMechanisms, authMechanism)
  410. }
  411. return authMechanisms
  412. }