client.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. /*
  2. Copyright The containerd Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ttrpc
  14. import (
  15. "context"
  16. "io"
  17. "net"
  18. "os"
  19. "strings"
  20. "sync"
  21. "syscall"
  22. "github.com/gogo/protobuf/proto"
  23. "github.com/pkg/errors"
  24. "github.com/sirupsen/logrus"
  25. "google.golang.org/grpc/status"
  26. )
// ErrClosed is returned by client methods when the underlying connection is
// closed. filterCloseErr maps io.EOF, "use of closed network connection"
// errors and EPIPE write errors onto this value.
var ErrClosed = errors.New("ttrpc: closed")
// Client issues requests over a single net.Conn. NewClient starts a
// background goroutine (run) that sends queued calls on the connection and
// routes incoming messages back to the callers waiting on them.
type Client struct {
	// codec marshals requests and unmarshals response payloads.
	codec codec
	// conn is the connection handed to NewClient; run closes it on exit.
	conn net.Conn
	// channel is built from conn by newChannel and used for send/recv.
	channel *channel
	// calls hands requests from dispatch to the run loop.
	calls chan *callRequest

	// closed is closed by Close to ask the run loop to stop.
	closed chan struct{}
	// closeOnce ensures closed is only closed once.
	closeOnce sync.Once

	// closeFunc is the callback deferred by run; OnClose replaces it.
	closeFunc func()
	// done is closed when the run loop exits.
	done chan struct{}
	// err is set by run before it exits and is returned by dispatch once
	// done is closed.
	err error
}
  41. func NewClient(conn net.Conn) *Client {
  42. c := &Client{
  43. codec: codec{},
  44. conn: conn,
  45. channel: newChannel(conn),
  46. calls: make(chan *callRequest),
  47. closed: make(chan struct{}),
  48. done: make(chan struct{}),
  49. closeFunc: func() {},
  50. }
  51. go c.run()
  52. return c
  53. }
// callRequest carries one call from dispatch to the run loop, together with
// the places where the run loop reports the outcome.
type callRequest struct {
	ctx  context.Context // forwarded to send when the request is written
	req  *Request
	resp *Response  // response will be written back here
	errs chan error // error written here on completion
}
  60. func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
  61. payload, err := c.codec.Marshal(req)
  62. if err != nil {
  63. return err
  64. }
  65. var (
  66. creq = &Request{
  67. Service: service,
  68. Method: method,
  69. Payload: payload,
  70. }
  71. cresp = &Response{}
  72. )
  73. if err := c.dispatch(ctx, creq, cresp); err != nil {
  74. return err
  75. }
  76. if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
  77. return err
  78. }
  79. if cresp.Status == nil {
  80. return errors.New("no status provided on response")
  81. }
  82. return status.ErrorProto(cresp.Status)
  83. }
  84. func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
  85. errs := make(chan error, 1)
  86. call := &callRequest{
  87. req: req,
  88. resp: resp,
  89. errs: errs,
  90. }
  91. select {
  92. case <-ctx.Done():
  93. return ctx.Err()
  94. case c.calls <- call:
  95. case <-c.done:
  96. return c.err
  97. }
  98. select {
  99. case <-ctx.Done():
  100. return ctx.Err()
  101. case err := <-errs:
  102. return filterCloseErr(err)
  103. case <-c.done:
  104. return c.err
  105. }
  106. }
  107. func (c *Client) Close() error {
  108. c.closeOnce.Do(func() {
  109. close(c.closed)
  110. })
  111. return nil
  112. }
// OnClose replaces the client's close callback with closer. The run loop
// defers a call to the callback, so it is tied to that loop exiting.
//
// NOTE(review): the field is assigned without synchronization while the run
// goroutine started by NewClient may already be using it — confirm this is
// safe for the callers' usage.
func (c *Client) OnClose(closer func()) {
	c.closeFunc = closer
}
// message is one received frame passed from the receive goroutine to the
// run loop: the frame header, the payload bytes, and any rpc status error
// that channel.recv returned alongside it.
type message struct {
	messageHeader
	p   []byte
	err error
}
  122. func (c *Client) run() {
  123. var (
  124. streamID uint32 = 1
  125. waiters = make(map[uint32]*callRequest)
  126. calls = c.calls
  127. incoming = make(chan *message)
  128. shutdown = make(chan struct{})
  129. shutdownErr error
  130. )
  131. go func() {
  132. defer close(shutdown)
  133. // start one more goroutine to recv messages without blocking.
  134. for {
  135. mh, p, err := c.channel.recv(context.TODO())
  136. if err != nil {
  137. _, ok := status.FromError(err)
  138. if !ok {
  139. // treat all errors that are not an rpc status as terminal.
  140. // all others poison the connection.
  141. shutdownErr = err
  142. return
  143. }
  144. }
  145. select {
  146. case incoming <- &message{
  147. messageHeader: mh,
  148. p: p[:mh.Length],
  149. err: err,
  150. }:
  151. case <-c.done:
  152. return
  153. }
  154. }
  155. }()
  156. defer c.conn.Close()
  157. defer close(c.done)
  158. defer c.closeFunc()
  159. for {
  160. select {
  161. case call := <-calls:
  162. if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
  163. call.errs <- err
  164. continue
  165. }
  166. waiters[streamID] = call
  167. streamID += 2 // enforce odd client initiated request ids
  168. case msg := <-incoming:
  169. call, ok := waiters[msg.StreamID]
  170. if !ok {
  171. logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
  172. continue
  173. }
  174. call.errs <- c.recv(call.resp, msg)
  175. delete(waiters, msg.StreamID)
  176. case <-shutdown:
  177. if shutdownErr != nil {
  178. shutdownErr = filterCloseErr(shutdownErr)
  179. } else {
  180. shutdownErr = ErrClosed
  181. }
  182. shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
  183. c.err = shutdownErr
  184. for _, waiter := range waiters {
  185. waiter.errs <- shutdownErr
  186. }
  187. c.Close()
  188. return
  189. case <-c.closed:
  190. if c.err == nil {
  191. c.err = ErrClosed
  192. }
  193. // broadcast the shutdown error to the remaining waiters.
  194. for _, waiter := range waiters {
  195. waiter.errs <- c.err
  196. }
  197. return
  198. }
  199. }
  200. }
  201. func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
  202. p, err := c.codec.Marshal(msg)
  203. if err != nil {
  204. return err
  205. }
  206. return c.channel.send(ctx, streamID, mtype, p)
  207. }
  208. func (c *Client) recv(resp *Response, msg *message) error {
  209. if msg.err != nil {
  210. return msg.err
  211. }
  212. if msg.Type != messageTypeResponse {
  213. return errors.New("unkown message type received")
  214. }
  215. defer c.channel.putmbuf(msg.p)
  216. return proto.Unmarshal(msg.p, resp)
  217. }
  218. // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
  219. // returning from call or handling errors from main read loop.
  220. //
  221. // This purposely ignores errors with a wrapped cause.
  222. func filterCloseErr(err error) error {
  223. if err == nil {
  224. return nil
  225. }
  226. if err == io.EOF {
  227. return ErrClosed
  228. }
  229. if strings.Contains(err.Error(), "use of closed network connection") {
  230. return ErrClosed
  231. }
  232. // if we have an epipe on a write, we cast to errclosed
  233. if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
  234. if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
  235. return ErrClosed
  236. }
  237. }
  238. return err
  239. }