123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- /*
- Copyright The containerd Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package ttrpc
- import (
- "context"
- "io"
- "net"
- "os"
- "strings"
- "sync"
- "syscall"
- "github.com/gogo/protobuf/proto"
- "github.com/pkg/errors"
- "github.com/sirupsen/logrus"
- "google.golang.org/grpc/status"
- )
var (
	// ErrClosed is returned by client methods when the underlying connection
	// is closed, whether by an explicit Close or because the connection was
	// lost (see filterCloseErr for the errors that are mapped to it).
	ErrClosed = errors.New("ttrpc: closed")
)
// Client is a ttrpc client bound to a single net.Conn. Calls are handed to a
// background run loop (started by NewClient), which writes requests to the
// connection and routes responses back to the waiting callers.
type Client struct {
	codec   codec             // marshals request/response payloads
	conn    net.Conn          // closed by the run loop when it exits
	channel *channel          // framed message transport over conn
	calls   chan *callRequest // dispatch -> run loop handoff of new calls

	closed    chan struct{} // closed by Close to ask the run loop to stop
	closeOnce sync.Once     // ensures closed is closed at most once
	closeFunc func()        // invoked by the run loop on exit; replaced via OnClose
	done      chan struct{} // closed when the run loop exits
	err       error         // terminal error; set by the run loop before done is closed
}
- func NewClient(conn net.Conn) *Client {
- c := &Client{
- codec: codec{},
- conn: conn,
- channel: newChannel(conn),
- calls: make(chan *callRequest),
- closed: make(chan struct{}),
- done: make(chan struct{}),
- closeFunc: func() {},
- }
- go c.run()
- return c
- }
// callRequest is one call handed from dispatch to the run loop over c.calls.
type callRequest struct {
	ctx  context.Context // context the run loop passes to send when writing req
	req  *Request        // request to write to the connection
	resp *Response       // response will be written back here
	errs chan error      // error written here on completion (nil on success)
}
- func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
- payload, err := c.codec.Marshal(req)
- if err != nil {
- return err
- }
- var (
- creq = &Request{
- Service: service,
- Method: method,
- Payload: payload,
- }
- cresp = &Response{}
- )
- if err := c.dispatch(ctx, creq, cresp); err != nil {
- return err
- }
- if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
- return err
- }
- if cresp.Status == nil {
- return errors.New("no status provided on response")
- }
- return status.ErrorProto(cresp.Status)
- }
- func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
- errs := make(chan error, 1)
- call := &callRequest{
- req: req,
- resp: resp,
- errs: errs,
- }
- select {
- case <-ctx.Done():
- return ctx.Err()
- case c.calls <- call:
- case <-c.done:
- return c.err
- }
- select {
- case <-ctx.Done():
- return ctx.Err()
- case err := <-errs:
- return filterCloseErr(err)
- case <-c.done:
- return c.err
- }
- }
- func (c *Client) Close() error {
- c.closeOnce.Do(func() {
- close(c.closed)
- })
- return nil
- }
// OnClose registers closer as the client's close func (c.closeFunc), which the
// client's run loop is meant to invoke when the client shuts down. It replaces
// the no-op installed by NewClient.
//
// NOTE(review): this assignment is not synchronized with the background run
// goroutine, which reads c.closeFunc; calling OnClose after NewClient has
// returned is a data race. Confirm callers register the closer early enough.
func (c *Client) OnClose(closer func()) {
	c.closeFunc = closer
}
// message is a frame read from the channel by the run loop's reader goroutine
// and forwarded to the loop for routing to its waiter.
type message struct {
	messageHeader        // frame header (StreamID, Type, Length) from channel.recv
	p             []byte // payload; returned to the channel via putmbuf after decoding
	err           error  // rpc status error from channel.recv; when set, p is not decoded
}
- func (c *Client) run() {
- var (
- streamID uint32 = 1
- waiters = make(map[uint32]*callRequest)
- calls = c.calls
- incoming = make(chan *message)
- shutdown = make(chan struct{})
- shutdownErr error
- )
- go func() {
- defer close(shutdown)
- // start one more goroutine to recv messages without blocking.
- for {
- mh, p, err := c.channel.recv(context.TODO())
- if err != nil {
- _, ok := status.FromError(err)
- if !ok {
- // treat all errors that are not an rpc status as terminal.
- // all others poison the connection.
- shutdownErr = err
- return
- }
- }
- select {
- case incoming <- &message{
- messageHeader: mh,
- p: p[:mh.Length],
- err: err,
- }:
- case <-c.done:
- return
- }
- }
- }()
- defer c.conn.Close()
- defer close(c.done)
- defer c.closeFunc()
- for {
- select {
- case call := <-calls:
- if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
- call.errs <- err
- continue
- }
- waiters[streamID] = call
- streamID += 2 // enforce odd client initiated request ids
- case msg := <-incoming:
- call, ok := waiters[msg.StreamID]
- if !ok {
- logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
- continue
- }
- call.errs <- c.recv(call.resp, msg)
- delete(waiters, msg.StreamID)
- case <-shutdown:
- if shutdownErr != nil {
- shutdownErr = filterCloseErr(shutdownErr)
- } else {
- shutdownErr = ErrClosed
- }
- shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
- c.err = shutdownErr
- for _, waiter := range waiters {
- waiter.errs <- shutdownErr
- }
- c.Close()
- return
- case <-c.closed:
- if c.err == nil {
- c.err = ErrClosed
- }
- // broadcast the shutdown error to the remaining waiters.
- for _, waiter := range waiters {
- waiter.errs <- c.err
- }
- return
- }
- }
- }
- func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
- p, err := c.codec.Marshal(msg)
- if err != nil {
- return err
- }
- return c.channel.send(ctx, streamID, mtype, p)
- }
- func (c *Client) recv(resp *Response, msg *message) error {
- if msg.err != nil {
- return msg.err
- }
- if msg.Type != messageTypeResponse {
- return errors.New("unkown message type received")
- }
- defer c.channel.putmbuf(msg.p)
- return proto.Unmarshal(msg.p, resp)
- }
- // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
- // returning from call or handling errors from main read loop.
- //
- // This purposely ignores errors with a wrapped cause.
- func filterCloseErr(err error) error {
- if err == nil {
- return nil
- }
- if err == io.EOF {
- return ErrClosed
- }
- if strings.Contains(err.Error(), "use of closed network connection") {
- return ErrClosed
- }
- // if we have an epipe on a write, we cast to errclosed
- if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
- if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
- return ErrClosed
- }
- }
- return err
- }
|