client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. /*
  2. Copyright The containerd Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ttrpc
  14. import (
  15. "context"
  16. "errors"
  17. "fmt"
  18. "io"
  19. "net"
  20. "strings"
  21. "sync"
  22. "syscall"
  23. "time"
  24. "github.com/sirupsen/logrus"
  25. "google.golang.org/grpc/codes"
  26. "google.golang.org/grpc/status"
  27. "google.golang.org/protobuf/proto"
  28. )
  29. // Client for a ttrpc server
  30. type Client struct {
  31. codec codec
  32. conn net.Conn
  33. channel *channel
  34. streamLock sync.RWMutex
  35. streams map[streamID]*stream
  36. nextStreamID streamID
  37. sendLock sync.Mutex
  38. ctx context.Context
  39. closed func()
  40. closeOnce sync.Once
  41. userCloseFunc func()
  42. userCloseWaitCh chan struct{}
  43. interceptor UnaryClientInterceptor
  44. }
  45. // ClientOpts configures a client
  46. type ClientOpts func(c *Client)
  47. // WithOnClose sets the close func whenever the client's Close() method is called
  48. func WithOnClose(onClose func()) ClientOpts {
  49. return func(c *Client) {
  50. c.userCloseFunc = onClose
  51. }
  52. }
  53. // WithUnaryClientInterceptor sets the provided client interceptor
  54. func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
  55. return func(c *Client) {
  56. c.interceptor = i
  57. }
  58. }
  59. // NewClient creates a new ttrpc client using the given connection
  60. func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
  61. ctx, cancel := context.WithCancel(context.Background())
  62. channel := newChannel(conn)
  63. c := &Client{
  64. codec: codec{},
  65. conn: conn,
  66. channel: channel,
  67. streams: make(map[streamID]*stream),
  68. nextStreamID: 1,
  69. closed: cancel,
  70. ctx: ctx,
  71. userCloseFunc: func() {},
  72. userCloseWaitCh: make(chan struct{}),
  73. interceptor: defaultClientInterceptor,
  74. }
  75. for _, o := range opts {
  76. o(c)
  77. }
  78. go c.run()
  79. return c
  80. }
  81. func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
  82. c.sendLock.Lock()
  83. defer c.sendLock.Unlock()
  84. return c.channel.send(sid, mt, flags, b)
  85. }
  86. // Call makes a unary request and returns with response
  87. func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
  88. payload, err := c.codec.Marshal(req)
  89. if err != nil {
  90. return err
  91. }
  92. var (
  93. creq = &Request{
  94. Service: service,
  95. Method: method,
  96. Payload: payload,
  97. // TODO: metadata from context
  98. }
  99. cresp = &Response{}
  100. )
  101. if metadata, ok := GetMetadata(ctx); ok {
  102. metadata.setRequest(creq)
  103. }
  104. if dl, ok := ctx.Deadline(); ok {
  105. creq.TimeoutNano = time.Until(dl).Nanoseconds()
  106. }
  107. info := &UnaryClientInfo{
  108. FullMethod: fullPath(service, method),
  109. }
  110. if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
  111. return err
  112. }
  113. if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
  114. return err
  115. }
  116. if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
  117. return status.ErrorProto(cresp.Status)
  118. }
  119. return nil
  120. }
  121. // StreamDesc describes the stream properties, whether the stream has
  122. // a streaming client, a streaming server, or both
  123. type StreamDesc struct {
  124. StreamingClient bool
  125. StreamingServer bool
  126. }
  127. // ClientStream is used to send or recv messages on the underlying stream
  128. type ClientStream interface {
  129. CloseSend() error
  130. SendMsg(m interface{}) error
  131. RecvMsg(m interface{}) error
  132. }
  133. type clientStream struct {
  134. ctx context.Context
  135. s *stream
  136. c *Client
  137. desc *StreamDesc
  138. localClosed bool
  139. remoteClosed bool
  140. }
  141. func (cs *clientStream) CloseSend() error {
  142. if !cs.desc.StreamingClient {
  143. return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
  144. }
  145. if cs.localClosed {
  146. return ErrStreamClosed
  147. }
  148. err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
  149. if err != nil {
  150. return filterCloseErr(err)
  151. }
  152. cs.localClosed = true
  153. return nil
  154. }
  155. func (cs *clientStream) SendMsg(m interface{}) error {
  156. if !cs.desc.StreamingClient {
  157. return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
  158. }
  159. if cs.localClosed {
  160. return ErrStreamClosed
  161. }
  162. var (
  163. payload []byte
  164. err error
  165. )
  166. if m != nil {
  167. payload, err = cs.c.codec.Marshal(m)
  168. if err != nil {
  169. return err
  170. }
  171. }
  172. err = cs.s.send(messageTypeData, 0, payload)
  173. if err != nil {
  174. return filterCloseErr(err)
  175. }
  176. return nil
  177. }
  178. func (cs *clientStream) RecvMsg(m interface{}) error {
  179. if cs.remoteClosed {
  180. return io.EOF
  181. }
  182. var msg *streamMessage
  183. select {
  184. case <-cs.ctx.Done():
  185. return cs.ctx.Err()
  186. case <-cs.s.recvClose:
  187. // If recv has a pending message, process that first
  188. select {
  189. case msg = <-cs.s.recv:
  190. default:
  191. return cs.s.recvErr
  192. }
  193. case msg = <-cs.s.recv:
  194. }
  195. if msg.header.Type == messageTypeResponse {
  196. resp := &Response{}
  197. err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
  198. // return the payload buffer for reuse
  199. cs.c.channel.putmbuf(msg.payload)
  200. if err != nil {
  201. return err
  202. }
  203. if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
  204. return err
  205. }
  206. if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
  207. return status.ErrorProto(resp.Status)
  208. }
  209. cs.c.deleteStream(cs.s)
  210. cs.remoteClosed = true
  211. return nil
  212. } else if msg.header.Type == messageTypeData {
  213. if !cs.desc.StreamingServer {
  214. cs.c.deleteStream(cs.s)
  215. cs.remoteClosed = true
  216. return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
  217. }
  218. if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
  219. cs.c.deleteStream(cs.s)
  220. cs.remoteClosed = true
  221. if msg.header.Flags&flagNoData == flagNoData {
  222. return io.EOF
  223. }
  224. }
  225. err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
  226. cs.c.channel.putmbuf(msg.payload)
  227. if err != nil {
  228. return err
  229. }
  230. return nil
  231. }
  232. return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
  233. }
  234. // Close closes the ttrpc connection and underlying connection
  235. func (c *Client) Close() error {
  236. c.closeOnce.Do(func() {
  237. c.closed()
  238. c.conn.Close()
  239. })
  240. return nil
  241. }
  242. // UserOnCloseWait is used to blocks untils the user's on-close callback
  243. // finishes.
  244. func (c *Client) UserOnCloseWait(ctx context.Context) error {
  245. select {
  246. case <-c.userCloseWaitCh:
  247. return nil
  248. case <-ctx.Done():
  249. return ctx.Err()
  250. }
  251. }
  252. func (c *Client) run() {
  253. err := c.receiveLoop()
  254. c.Close()
  255. c.cleanupStreams(err)
  256. c.userCloseFunc()
  257. close(c.userCloseWaitCh)
  258. }
  259. func (c *Client) receiveLoop() error {
  260. for {
  261. select {
  262. case <-c.ctx.Done():
  263. return ErrClosed
  264. default:
  265. var (
  266. msg = &streamMessage{}
  267. err error
  268. )
  269. msg.header, msg.payload, err = c.channel.recv()
  270. if err != nil {
  271. _, ok := status.FromError(err)
  272. if !ok {
  273. // treat all errors that are not an rpc status as terminal.
  274. // all others poison the connection.
  275. return filterCloseErr(err)
  276. }
  277. }
  278. sid := streamID(msg.header.StreamID)
  279. s := c.getStream(sid)
  280. if s == nil {
  281. logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
  282. continue
  283. }
  284. if err != nil {
  285. s.closeWithError(err)
  286. } else {
  287. if err := s.receive(c.ctx, msg); err != nil {
  288. logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
  289. }
  290. }
  291. }
  292. }
  293. }
  294. // createStream creates a new stream and registers it with the client
  295. // Introduce stream types for multiple or single response
  296. func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
  297. c.streamLock.Lock()
  298. // Check if closed since lock acquired to prevent adding
  299. // anything after cleanup completes
  300. select {
  301. case <-c.ctx.Done():
  302. c.streamLock.Unlock()
  303. return nil, ErrClosed
  304. default:
  305. }
  306. // Stream ID should be allocated at same time
  307. s := newStream(c.nextStreamID, c)
  308. c.streams[s.id] = s
  309. c.nextStreamID = c.nextStreamID + 2
  310. c.sendLock.Lock()
  311. defer c.sendLock.Unlock()
  312. c.streamLock.Unlock()
  313. if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
  314. return s, filterCloseErr(err)
  315. }
  316. return s, nil
  317. }
  318. func (c *Client) deleteStream(s *stream) {
  319. c.streamLock.Lock()
  320. delete(c.streams, s.id)
  321. c.streamLock.Unlock()
  322. s.closeWithError(nil)
  323. }
  324. func (c *Client) getStream(sid streamID) *stream {
  325. c.streamLock.RLock()
  326. s := c.streams[sid]
  327. c.streamLock.RUnlock()
  328. return s
  329. }
  330. func (c *Client) cleanupStreams(err error) {
  331. c.streamLock.Lock()
  332. defer c.streamLock.Unlock()
  333. for sid, s := range c.streams {
  334. s.closeWithError(err)
  335. delete(c.streams, sid)
  336. }
  337. }
  338. // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
  339. // returning from call or handling errors from main read loop.
  340. //
  341. // This purposely ignores errors with a wrapped cause.
  342. func filterCloseErr(err error) error {
  343. switch {
  344. case err == nil:
  345. return nil
  346. case err == io.EOF:
  347. return ErrClosed
  348. case errors.Is(err, io.ErrClosedPipe):
  349. return ErrClosed
  350. case errors.Is(err, io.EOF):
  351. return ErrClosed
  352. case strings.Contains(err.Error(), "use of closed network connection"):
  353. return ErrClosed
  354. default:
  355. // if we have an epipe on a write or econnreset on a read , we cast to errclosed
  356. var oerr *net.OpError
  357. if errors.As(err, &oerr) {
  358. if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
  359. (oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
  360. return ErrClosed
  361. }
  362. }
  363. }
  364. return err
  365. }
  366. // NewStream creates a new stream with the given stream descriptor to the
  367. // specified service and method. If not a streaming client, the request object
  368. // may be provided.
  369. func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
  370. var payload []byte
  371. if req != nil {
  372. var err error
  373. payload, err = c.codec.Marshal(req)
  374. if err != nil {
  375. return nil, err
  376. }
  377. }
  378. request := &Request{
  379. Service: service,
  380. Method: method,
  381. Payload: payload,
  382. // TODO: metadata from context
  383. }
  384. p, err := c.codec.Marshal(request)
  385. if err != nil {
  386. return nil, err
  387. }
  388. var flags uint8
  389. if desc.StreamingClient {
  390. flags = flagRemoteOpen
  391. } else {
  392. flags = flagRemoteClosed
  393. }
  394. s, err := c.createStream(flags, p)
  395. if err != nil {
  396. return nil, err
  397. }
  398. return &clientStream{
  399. ctx: ctx,
  400. s: s,
  401. c: c,
  402. desc: desc,
  403. }, nil
  404. }
  405. func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
  406. p, err := c.codec.Marshal(req)
  407. if err != nil {
  408. return err
  409. }
  410. s, err := c.createStream(0, p)
  411. if err != nil {
  412. return err
  413. }
  414. defer c.deleteStream(s)
  415. var msg *streamMessage
  416. select {
  417. case <-ctx.Done():
  418. return ctx.Err()
  419. case <-c.ctx.Done():
  420. return ErrClosed
  421. case <-s.recvClose:
  422. // If recv has a pending message, process that first
  423. select {
  424. case msg = <-s.recv:
  425. default:
  426. return s.recvErr
  427. }
  428. case msg = <-s.recv:
  429. }
  430. if msg.header.Type == messageTypeResponse {
  431. err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
  432. } else {
  433. err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
  434. }
  435. // return the payload buffer for reuse
  436. c.channel.putmbuf(msg.payload)
  437. return err
  438. }