|
@@ -18,6 +18,7 @@ package ttrpc
|
|
|
|
|
|
import (
|
|
import (
|
|
"context"
|
|
"context"
|
|
|
|
+ "errors"
|
|
"io"
|
|
"io"
|
|
"net"
|
|
"net"
|
|
"os"
|
|
"os"
|
|
@@ -27,7 +28,6 @@ import (
|
|
"time"
|
|
"time"
|
|
|
|
|
|
"github.com/gogo/protobuf/proto"
|
|
"github.com/gogo/protobuf/proto"
|
|
- "github.com/pkg/errors"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/sirupsen/logrus"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/grpc/status"
|
|
@@ -194,72 +194,131 @@ type message struct {
|
|
err error
|
|
err error
|
|
}
|
|
}
|
|
|
|
|
|
-type receiver struct {
|
|
|
|
- wg *sync.WaitGroup
|
|
|
|
- messages chan *message
|
|
|
|
- err error
|
|
|
|
|
|
+// callMap provides access to a map of active calls, guarded by a mutex.
|
|
|
|
+type callMap struct {
|
|
|
|
+ m sync.Mutex
|
|
|
|
+ activeCalls map[uint32]*callRequest
|
|
|
|
+ closeErr error
|
|
}
|
|
}
|
|
|
|
|
|
-func (r *receiver) run(ctx context.Context, c *channel) {
|
|
|
|
- defer r.wg.Done()
|
|
|
|
|
|
+// newCallMap returns a new callMap with an empty set of active calls.
|
|
|
|
+func newCallMap() *callMap {
|
|
|
|
+ return &callMap{
|
|
|
|
+ activeCalls: make(map[uint32]*callRequest),
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
|
|
- for {
|
|
|
|
- select {
|
|
|
|
- case <-ctx.Done():
|
|
|
|
- r.err = ctx.Err()
|
|
|
|
- return
|
|
|
|
- default:
|
|
|
|
- mh, p, err := c.recv()
|
|
|
|
- if err != nil {
|
|
|
|
- _, ok := status.FromError(err)
|
|
|
|
- if !ok {
|
|
|
|
- // treat all errors that are not an rpc status as terminal.
|
|
|
|
- // all others poison the connection.
|
|
|
|
- r.err = filterCloseErr(err)
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- select {
|
|
|
|
- case r.messages <- &message{
|
|
|
|
- messageHeader: mh,
|
|
|
|
- p: p[:mh.Length],
|
|
|
|
- err: err,
|
|
|
|
- }:
|
|
|
|
- case <-ctx.Done():
|
|
|
|
- r.err = ctx.Err()
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
|
|
+// set adds a call entry to the map with the given streamID key.
|
|
|
|
+func (cm *callMap) set(streamID uint32, cr *callRequest) error {
|
|
|
|
+ cm.m.Lock()
|
|
|
|
+ defer cm.m.Unlock()
|
|
|
|
+ if cm.closeErr != nil {
|
|
|
|
+ return cm.closeErr
|
|
}
|
|
}
|
|
|
|
+ cm.activeCalls[streamID] = cr
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// get looks up the call entry for the given streamID key, then removes it
|
|
|
|
+// from the map and returns it.
|
|
|
|
+func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
|
|
|
|
+ cm.m.Lock()
|
|
|
|
+ defer cm.m.Unlock()
|
|
|
|
+ if cm.closeErr != nil {
|
|
|
|
+ return nil, false, cm.closeErr
|
|
|
|
+ }
|
|
|
|
+ cr, ok = cm.activeCalls[streamID]
|
|
|
|
+ if ok {
|
|
|
|
+ delete(cm.activeCalls, streamID)
|
|
|
|
+ }
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// abort sends the given error to each active call, and clears the map.
|
|
|
|
+// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
|
|
|
|
+func (cm *callMap) abort(err error) error {
|
|
|
|
+ cm.m.Lock()
|
|
|
|
+ defer cm.m.Unlock()
|
|
|
|
+ if cm.closeErr != nil {
|
|
|
|
+ return cm.closeErr
|
|
|
|
+ }
|
|
|
|
+ for streamID, call := range cm.activeCalls {
|
|
|
|
+ call.errs <- err
|
|
|
|
+ delete(cm.activeCalls, streamID)
|
|
|
|
+ }
|
|
|
|
+ cm.closeErr = err
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
func (c *Client) run() {
|
|
func (c *Client) run() {
|
|
var (
|
|
var (
|
|
- streamID uint32 = 1
|
|
|
|
- waiters = make(map[uint32]*callRequest)
|
|
|
|
- calls = c.calls
|
|
|
|
- incoming = make(chan *message)
|
|
|
|
- receiversDone = make(chan struct{})
|
|
|
|
- wg sync.WaitGroup
|
|
|
|
|
|
+ waiters = newCallMap()
|
|
|
|
+ receiverDone = make(chan struct{})
|
|
)
|
|
)
|
|
|
|
|
|
- // broadcast the shutdown error to the remaining waiters.
|
|
|
|
- abortWaiters := func(wErr error) {
|
|
|
|
- for _, waiter := range waiters {
|
|
|
|
- waiter.errs <- wErr
|
|
|
|
|
|
+ // Sender goroutine
|
|
|
|
+ // Receives calls from dispatch, adds them to the set of active calls, and sends them
|
|
|
|
+ // to the server.
|
|
|
|
+ go func() {
|
|
|
|
+ var streamID uint32 = 1
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-c.ctx.Done():
|
|
|
|
+ return
|
|
|
|
+ case call := <-c.calls:
|
|
|
|
+ id := streamID
|
|
|
|
+ streamID += 2 // enforce odd client initiated request ids
|
|
|
|
+ if err := waiters.set(id, call); err != nil {
|
|
|
|
+ call.errs <- err // errs is buffered so should not block.
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+ if err := c.send(id, messageTypeRequest, call.req); err != nil {
|
|
|
|
+ call.errs <- err // errs is buffered so should not block.
|
|
|
|
+ waiters.get(id) // remove from waiters set
|
|
|
|
+ }
|
|
|
|
+ }
|
|
}
|
|
}
|
|
- }
|
|
|
|
- recv := &receiver{
|
|
|
|
- wg: &wg,
|
|
|
|
- messages: incoming,
|
|
|
|
- }
|
|
|
|
- wg.Add(1)
|
|
|
|
|
|
+ }()
|
|
|
|
|
|
|
|
+ // Receiver goroutine
|
|
|
|
+ // Receives responses from the server, looks up the call info in the set of active calls,
|
|
|
|
+ // and notifies the caller of the response.
|
|
go func() {
|
|
go func() {
|
|
- wg.Wait()
|
|
|
|
- close(receiversDone)
|
|
|
|
|
|
+ defer close(receiverDone)
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-c.ctx.Done():
|
|
|
|
+ c.setError(c.ctx.Err())
|
|
|
|
+ return
|
|
|
|
+ default:
|
|
|
|
+ mh, p, err := c.channel.recv()
|
|
|
|
+ if err != nil {
|
|
|
|
+ _, ok := status.FromError(err)
|
|
|
|
+ if !ok {
|
|
|
|
+ // treat all errors that are not an rpc status as terminal.
|
|
|
|
+ // all others poison the connection.
|
|
|
|
+ c.setError(filterCloseErr(err))
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ msg := &message{
|
|
|
|
+ messageHeader: mh,
|
|
|
|
+ p: p[:mh.Length],
|
|
|
|
+ err: err,
|
|
|
|
+ }
|
|
|
|
+ call, ok, err := waiters.get(mh.StreamID)
|
|
|
|
+ if err != nil {
|
|
|
|
+ logrus.Errorf("ttrpc: failed to look up active call: %s", err)
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+ if !ok {
|
|
|
|
+ logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+ call.errs <- c.recv(call.resp, msg)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
}()
|
|
}()
|
|
- go recv.run(c.ctx, c.channel)
|
|
|
|
|
|
|
|
defer func() {
|
|
defer func() {
|
|
c.conn.Close()
|
|
c.conn.Close()
|
|
@@ -269,32 +328,14 @@ func (c *Client) run() {
|
|
|
|
|
|
for {
|
|
for {
|
|
select {
|
|
select {
|
|
- case call := <-calls:
|
|
|
|
- if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
|
|
|
|
- call.errs <- err
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- waiters[streamID] = call
|
|
|
|
- streamID += 2 // enforce odd client initiated request ids
|
|
|
|
- case msg := <-incoming:
|
|
|
|
- call, ok := waiters[msg.StreamID]
|
|
|
|
- if !ok {
|
|
|
|
- logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
|
|
|
|
- continue
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- call.errs <- c.recv(call.resp, msg)
|
|
|
|
- delete(waiters, msg.StreamID)
|
|
|
|
- case <-receiversDone:
|
|
|
|
- // all the receivers have exited
|
|
|
|
- if recv.err != nil {
|
|
|
|
- c.setError(recv.err)
|
|
|
|
- }
|
|
|
|
|
|
+ case <-receiverDone:
|
|
|
|
+ // The receiver has exited.
|
|
// don't return out, let the close of the context trigger the abort of waiters
|
|
// don't return out, let the close of the context trigger the abort of waiters
|
|
c.Close()
|
|
c.Close()
|
|
case <-c.ctx.Done():
|
|
case <-c.ctx.Done():
|
|
- abortWaiters(c.error())
|
|
|
|
|
|
+ // Abort all active calls. This will also prevent any new calls from being added
|
|
|
|
+ // to waiters.
|
|
|
|
+ waiters.abort(c.error())
|
|
return
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -347,7 +388,7 @@ func filterCloseErr(err error) error {
|
|
return nil
|
|
return nil
|
|
case err == io.EOF:
|
|
case err == io.EOF:
|
|
return ErrClosed
|
|
return ErrClosed
|
|
- case errors.Cause(err) == io.EOF:
|
|
|
|
|
|
+ case errors.Is(err, io.EOF):
|
|
return ErrClosed
|
|
return ErrClosed
|
|
case strings.Contains(err.Error(), "use of closed network connection"):
|
|
case strings.Contains(err.Error(), "use of closed network connection"):
|
|
return ErrClosed
|
|
return ErrClosed
|