|
@@ -29,6 +29,7 @@ import (
|
|
|
"github.com/gogo/protobuf/proto"
|
|
|
"github.com/pkg/errors"
|
|
|
"github.com/sirupsen/logrus"
|
|
|
+ "google.golang.org/grpc/codes"
|
|
|
"google.golang.org/grpc/status"
|
|
|
)
|
|
|
|
|
@@ -36,36 +37,52 @@ import (
|
|
|
// closed.
|
|
|
var ErrClosed = errors.New("ttrpc: closed")
|
|
|
|
|
|
+// Client is the client for a ttrpc server.
|
|
|
type Client struct {
|
|
|
codec codec
|
|
|
conn net.Conn
|
|
|
channel *channel
|
|
|
calls chan *callRequest
|
|
|
|
|
|
- closed chan struct{}
|
|
|
- closeOnce sync.Once
|
|
|
- closeFunc func()
|
|
|
- done chan struct{}
|
|
|
- err error
|
|
|
+ ctx context.Context
|
|
|
+ closed func()
|
|
|
+
|
|
|
+ closeOnce sync.Once
|
|
|
+ userCloseFunc func()
|
|
|
+
|
|
|
+ errOnce sync.Once
|
|
|
+ err error
|
|
|
+ interceptor UnaryClientInterceptor
|
|
|
}
|
|
|
|
|
|
+// ClientOpts configures a client
|
|
|
type ClientOpts func(c *Client)
|
|
|
|
|
|
+// WithOnClose sets a callback that is invoked whenever the client's Close() method is called.
|
|
|
func WithOnClose(onClose func()) ClientOpts {
|
|
|
return func(c *Client) {
|
|
|
- c.closeFunc = onClose
|
|
|
+ c.userCloseFunc = onClose
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// WithUnaryClientInterceptor sets the provided interceptor for unary client calls.
|
|
|
+func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
|
|
|
+ return func(c *Client) {
|
|
|
+ c.interceptor = i
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
c := &Client{
|
|
|
- codec: codec{},
|
|
|
- conn: conn,
|
|
|
- channel: newChannel(conn),
|
|
|
- calls: make(chan *callRequest),
|
|
|
- closed: make(chan struct{}),
|
|
|
- done: make(chan struct{}),
|
|
|
- closeFunc: func() {},
|
|
|
+ codec: codec{},
|
|
|
+ conn: conn,
|
|
|
+ channel: newChannel(conn),
|
|
|
+ calls: make(chan *callRequest),
|
|
|
+ closed: cancel,
|
|
|
+ ctx: ctx,
|
|
|
+ userCloseFunc: func() {},
|
|
|
+ interceptor: defaultClientInterceptor,
|
|
|
}
|
|
|
|
|
|
for _, o := range opts {
|
|
@@ -99,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
|
|
|
cresp = &Response{}
|
|
|
)
|
|
|
|
|
|
+ if metadata, ok := GetMetadata(ctx); ok {
|
|
|
+ metadata.setRequest(creq)
|
|
|
+ }
|
|
|
+
|
|
|
if dl, ok := ctx.Deadline(); ok {
|
|
|
creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
|
|
|
}
|
|
|
|
|
|
- if err := c.dispatch(ctx, creq, cresp); err != nil {
|
|
|
+ info := &UnaryClientInfo{
|
|
|
+ FullMethod: fullPath(service, method),
|
|
|
+ }
|
|
|
+ if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
@@ -111,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- if cresp.Status == nil {
|
|
|
- return errors.New("no status provided on response")
|
|
|
+ if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
|
|
|
+ return status.ErrorProto(cresp.Status)
|
|
|
}
|
|
|
-
|
|
|
- return status.ErrorProto(cresp.Status)
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
|
|
@@ -131,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
|
|
case <-ctx.Done():
|
|
|
return ctx.Err()
|
|
|
case c.calls <- call:
|
|
|
- case <-c.done:
|
|
|
- return c.err
|
|
|
+ case <-c.ctx.Done():
|
|
|
+ return c.error()
|
|
|
}
|
|
|
|
|
|
select {
|
|
@@ -140,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
|
|
return ctx.Err()
|
|
|
case err := <-errs:
|
|
|
return filterCloseErr(err)
|
|
|
- case <-c.done:
|
|
|
- return c.err
|
|
|
+ case <-c.ctx.Done():
|
|
|
+ return c.error()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (c *Client) Close() error {
|
|
|
c.closeOnce.Do(func() {
|
|
|
- close(c.closed)
|
|
|
+ c.closed()
|
|
|
})
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
@@ -159,51 +181,82 @@ type message struct {
|
|
|
err error
|
|
|
}
|
|
|
|
|
|
-func (c *Client) run() {
|
|
|
- var (
|
|
|
- streamID uint32 = 1
|
|
|
- waiters = make(map[uint32]*callRequest)
|
|
|
- calls = c.calls
|
|
|
- incoming = make(chan *message)
|
|
|
- shutdown = make(chan struct{})
|
|
|
- shutdownErr error
|
|
|
- )
|
|
|
+type receiver struct {
|
|
|
+ wg *sync.WaitGroup
|
|
|
+ messages chan *message
|
|
|
+ err error
|
|
|
+}
|
|
|
|
|
|
- go func() {
|
|
|
- defer close(shutdown)
|
|
|
+func (r *receiver) run(ctx context.Context, c *channel) {
|
|
|
+ defer r.wg.Done()
|
|
|
|
|
|
- // start one more goroutine to recv messages without blocking.
|
|
|
- for {
|
|
|
- mh, p, err := c.channel.recv(context.TODO())
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ r.err = ctx.Err()
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ mh, p, err := c.recv()
|
|
|
if err != nil {
|
|
|
_, ok := status.FromError(err)
|
|
|
if !ok {
|
|
|
// treat all errors that are not an rpc status as terminal.
|
|
|
// all others poison the connection.
|
|
|
- shutdownErr = err
|
|
|
+ r.err = filterCloseErr(err)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
select {
|
|
|
- case incoming <- &message{
|
|
|
+ case r.messages <- &message{
|
|
|
messageHeader: mh,
|
|
|
p: p[:mh.Length],
|
|
|
err: err,
|
|
|
}:
|
|
|
- case <-c.done:
|
|
|
+ case <-ctx.Done():
|
|
|
+ r.err = ctx.Err()
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Client) run() {
|
|
|
+ var (
|
|
|
+ streamID uint32 = 1
|
|
|
+ waiters = make(map[uint32]*callRequest)
|
|
|
+ calls = c.calls
|
|
|
+ incoming = make(chan *message)
|
|
|
+ receiversDone = make(chan struct{})
|
|
|
+ wg sync.WaitGroup
|
|
|
+ )
|
|
|
+
|
|
|
+ // broadcast the shutdown error to the remaining waiters.
|
|
|
+ abortWaiters := func(wErr error) {
|
|
|
+ for _, waiter := range waiters {
|
|
|
+ waiter.errs <- wErr
|
|
|
+ }
|
|
|
+ }
|
|
|
+ recv := &receiver{
|
|
|
+ wg: &wg,
|
|
|
+ messages: incoming,
|
|
|
+ }
|
|
|
+ wg.Add(1)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ wg.Wait()
|
|
|
+ close(receiversDone)
|
|
|
}()
|
|
|
+ go recv.run(c.ctx, c.channel)
|
|
|
|
|
|
- defer c.conn.Close()
|
|
|
- defer close(c.done)
|
|
|
- defer c.closeFunc()
|
|
|
+ defer func() {
|
|
|
+ c.conn.Close()
|
|
|
+ c.userCloseFunc()
|
|
|
+ }()
|
|
|
|
|
|
for {
|
|
|
select {
|
|
|
case call := <-calls:
|
|
|
- if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
|
|
+ if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
|
|
|
call.errs <- err
|
|
|
continue
|
|
|
}
|
|
@@ -219,41 +272,42 @@ func (c *Client) run() {
|
|
|
|
|
|
call.errs <- c.recv(call.resp, msg)
|
|
|
delete(waiters, msg.StreamID)
|
|
|
- case <-shutdown:
|
|
|
- if shutdownErr != nil {
|
|
|
- shutdownErr = filterCloseErr(shutdownErr)
|
|
|
- } else {
|
|
|
- shutdownErr = ErrClosed
|
|
|
- }
|
|
|
-
|
|
|
- shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
|
|
|
-
|
|
|
- c.err = shutdownErr
|
|
|
- for _, waiter := range waiters {
|
|
|
- waiter.errs <- shutdownErr
|
|
|
+ case <-receiversDone:
|
|
|
+ // all the receivers have exited
|
|
|
+ if recv.err != nil {
|
|
|
+ c.setError(recv.err)
|
|
|
}
|
|
|
+ // don't return out, let the close of the context trigger the abort of waiters
|
|
|
c.Close()
|
|
|
- return
|
|
|
- case <-c.closed:
|
|
|
- if c.err == nil {
|
|
|
- c.err = ErrClosed
|
|
|
- }
|
|
|
- // broadcast the shutdown error to the remaining waiters.
|
|
|
- for _, waiter := range waiters {
|
|
|
- waiter.errs <- c.err
|
|
|
- }
|
|
|
+ case <-c.ctx.Done():
|
|
|
+ abortWaiters(c.error())
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
|
|
+func (c *Client) error() error {
|
|
|
+ c.errOnce.Do(func() {
|
|
|
+ if c.err == nil {
|
|
|
+ c.err = ErrClosed
|
|
|
+ }
|
|
|
+ })
|
|
|
+ return c.err
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Client) setError(err error) {
|
|
|
+ c.errOnce.Do(func() {
|
|
|
+ c.err = err
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
|
|
|
p, err := c.codec.Marshal(msg)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- return c.channel.send(ctx, streamID, mtype, p)
|
|
|
+ return c.channel.send(streamID, mtype, p)
|
|
|
}
|
|
|
|
|
|
func (c *Client) recv(resp *Response, msg *message) error {
|
|
@@ -274,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
|
|
|
//
|
|
|
// This purposely ignores errors with a wrapped cause.
|
|
|
func filterCloseErr(err error) error {
|
|
|
- if err == nil {
|
|
|
+ switch {
|
|
|
+ case err == nil:
|
|
|
return nil
|
|
|
- }
|
|
|
-
|
|
|
- if err == io.EOF {
|
|
|
+ case err == io.EOF:
|
|
|
return ErrClosed
|
|
|
- }
|
|
|
-
|
|
|
- if strings.Contains(err.Error(), "use of closed network connection") {
|
|
|
+ case errors.Cause(err) == io.EOF:
|
|
|
return ErrClosed
|
|
|
- }
|
|
|
-
|
|
|
- // if we have an epipe on a write, we cast to errclosed
|
|
|
- if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
|
|
- if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
|
|
- return ErrClosed
|
|
|
+ case strings.Contains(err.Error(), "use of closed network connection"):
|
|
|
+ return ErrClosed
|
|
|
+ default:
|
|
|
+ // if we have an epipe on a write, we cast to errclosed
|
|
|
+ if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
|
|
+ if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
|
|
+ return ErrClosed
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|