|
@@ -47,8 +47,9 @@ type Client struct {
|
|
|
ctx context.Context
|
|
|
closed func()
|
|
|
|
|
|
- closeOnce sync.Once
|
|
|
- userCloseFunc func()
|
|
|
+ closeOnce sync.Once
|
|
|
+ userCloseFunc func()
|
|
|
+ userCloseWaitCh chan struct{}
|
|
|
|
|
|
errOnce sync.Once
|
|
|
err error
|
|
@@ -75,14 +76,15 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
|
|
|
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
c := &Client{
|
|
|
- codec: codec{},
|
|
|
- conn: conn,
|
|
|
- channel: newChannel(conn),
|
|
|
- calls: make(chan *callRequest),
|
|
|
- closed: cancel,
|
|
|
- ctx: ctx,
|
|
|
- userCloseFunc: func() {},
|
|
|
- interceptor: defaultClientInterceptor,
|
|
|
+ codec: codec{},
|
|
|
+ conn: conn,
|
|
|
+ channel: newChannel(conn),
|
|
|
+ calls: make(chan *callRequest),
|
|
|
+ closed: cancel,
|
|
|
+ ctx: ctx,
|
|
|
+ userCloseFunc: func() {},
|
|
|
+ userCloseWaitCh: make(chan struct{}),
|
|
|
+ interceptor: defaultClientInterceptor,
|
|
|
}
|
|
|
|
|
|
for _, o := range opts {
|
|
@@ -175,6 +177,17 @@ func (c *Client) Close() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// UserOnCloseWait is used to blocks untils the user's on-close callback
|
|
|
+// finishes.
|
|
|
+func (c *Client) UserOnCloseWait(ctx context.Context) error {
|
|
|
+ select {
|
|
|
+ case <-c.userCloseWaitCh:
|
|
|
+ return nil
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
type message struct {
|
|
|
messageHeader
|
|
|
p []byte
|
|
@@ -251,6 +264,7 @@ func (c *Client) run() {
|
|
|
defer func() {
|
|
|
c.conn.Close()
|
|
|
c.userCloseFunc()
|
|
|
+ close(c.userCloseWaitCh)
|
|
|
}()
|
|
|
|
|
|
for {
|
|
@@ -339,7 +353,8 @@ func filterCloseErr(err error) error {
|
|
|
return ErrClosed
|
|
|
default:
|
|
|
// if we have an epipe on a write or econnreset on a read , we cast to errclosed
|
|
|
- if oerr, ok := err.(*net.OpError); ok && (oerr.Op == "write" || oerr.Op == "read") {
|
|
|
+ var oerr *net.OpError
|
|
|
+ if errors.As(err, &oerr) && (oerr.Op == "write" || oerr.Op == "read") {
|
|
|
serr, sok := oerr.Err.(*os.SyscallError)
|
|
|
if sok && ((serr.Err == syscall.EPIPE && oerr.Op == "write") ||
|
|
|
(serr.Err == syscall.ECONNRESET && oerr.Op == "read")) {
|