Przeglądaj źródła

bump containerd/ttrpc 92c8520ef9f86600c650dd540266a007bf03670f

full diff: https://github.com/containerd/ttrpc/compare/699c4e40d1e7416e08bf7019c7ce2e9beced4636...92c8520ef9f86600c650dd540266a007bf03670f

changes:

- containerd/ttrpc#37 Handle EOF to prevent file descriptor leak
- containerd/ttrpc#38 Improve connection error handling
- containerd/ttrpc#40 Support headers
- containerd/ttrpc#41 Add client and server unary interceptors
- containerd/ttrpc#43 metadata as KeyValue type
- containerd/ttrpc#42 Refactor close handling for ttrpc clients
- containerd/ttrpc#44 Fix method full name generation
- containerd/ttrpc#46 Client.Call(): do not return error if no Status is set (gRPC v1.23 and up)
- containerd/ttrpc#49 Handle ok status

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
(cherry picked from commit 8769255d1bb9c469d4f2966e7e9869a9f126f9e9)
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 5 lat temu
rodzic
commit
525e8ed3fe

+ 1 - 1
vendor.conf

@@ -126,7 +126,7 @@ github.com/containerd/cgroups                       4994991857f9b0ae8dc439551e8b
 github.com/containerd/console                       0650fd9eeb50bab4fc99dceb9f2e14cf58f36e7f
 github.com/containerd/go-runc                       7d11b49dc0769f6dbb0d1b19f3d48524d1bad9ad
 github.com/containerd/typeurl                       2a93cfde8c20b23de8eb84a5adbc234ddf7a9e8d
-github.com/containerd/ttrpc                         699c4e40d1e7416e08bf7019c7ce2e9beced4636
+github.com/containerd/ttrpc                         92c8520ef9f86600c650dd540266a007bf03670f
 github.com/gogo/googleapis                          d31c731455cb061f42baff3bda55bad0118b126b # v1.2.0
 
 # cluster

+ 2 - 3
vendor/github.com/containerd/ttrpc/channel.go

@@ -18,7 +18,6 @@ package ttrpc
 
 import (
 	"bufio"
-	"context"
 	"encoding/binary"
 	"io"
 	"net"
@@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
 // returned will be valid and caller should send that along to
 // the correct consumer. The bytes on the underlying channel
 // will be discarded.
-func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
+func (ch *channel) recv() (messageHeader, []byte, error) {
 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
 	if err != nil {
 		return messageHeader{}, nil, err
@@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
 	return mh, p, nil
 }
 
-func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
+func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
 	if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
 		return err
 	}

+ 134 - 81
vendor/github.com/containerd/ttrpc/client.go

@@ -29,6 +29,7 @@ import (
 	"github.com/gogo/protobuf/proto"
 	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
@@ -36,36 +37,52 @@ import (
 // closed.
 var ErrClosed = errors.New("ttrpc: closed")
 
+// Client for a ttrpc server
 type Client struct {
 	codec   codec
 	conn    net.Conn
 	channel *channel
 	calls   chan *callRequest
 
-	closed    chan struct{}
-	closeOnce sync.Once
-	closeFunc func()
-	done      chan struct{}
-	err       error
+	ctx    context.Context
+	closed func()
+
+	closeOnce     sync.Once
+	userCloseFunc func()
+
+	errOnce     sync.Once
+	err         error
+	interceptor UnaryClientInterceptor
 }
 
+// ClientOpts configures a client
 type ClientOpts func(c *Client)
 
+// WithOnClose sets the close func whenever the client's Close() method is called
 func WithOnClose(onClose func()) ClientOpts {
 	return func(c *Client) {
-		c.closeFunc = onClose
+		c.userCloseFunc = onClose
+	}
+}
+
+// WithUnaryClientInterceptor sets the provided client interceptor
+func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
+	return func(c *Client) {
+		c.interceptor = i
 	}
 }
 
 func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
+	ctx, cancel := context.WithCancel(context.Background())
 	c := &Client{
-		codec:     codec{},
-		conn:      conn,
-		channel:   newChannel(conn),
-		calls:     make(chan *callRequest),
-		closed:    make(chan struct{}),
-		done:      make(chan struct{}),
-		closeFunc: func() {},
+		codec:         codec{},
+		conn:          conn,
+		channel:       newChannel(conn),
+		calls:         make(chan *callRequest),
+		closed:        cancel,
+		ctx:           ctx,
+		userCloseFunc: func() {},
+		interceptor:   defaultClientInterceptor,
 	}
 
 	for _, o := range opts {
@@ -99,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
 		cresp = &Response{}
 	)
 
+	if metadata, ok := GetMetadata(ctx); ok {
+		metadata.setRequest(creq)
+	}
+
 	if dl, ok := ctx.Deadline(); ok {
 		creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
 	}
 
-	if err := c.dispatch(ctx, creq, cresp); err != nil {
+	info := &UnaryClientInfo{
+		FullMethod: fullPath(service, method),
+	}
+	if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
 		return err
 	}
 
@@ -111,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
 		return err
 	}
 
-	if cresp.Status == nil {
-		return errors.New("no status provided on response")
+	if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
+		return status.ErrorProto(cresp.Status)
 	}
-
-	return status.ErrorProto(cresp.Status)
+	return nil
 }
 
 func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
@@ -131,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
 	case <-ctx.Done():
 		return ctx.Err()
 	case c.calls <- call:
-	case <-c.done:
-		return c.err
+	case <-c.ctx.Done():
+		return c.error()
 	}
 
 	select {
@@ -140,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
 		return ctx.Err()
 	case err := <-errs:
 		return filterCloseErr(err)
-	case <-c.done:
-		return c.err
+	case <-c.ctx.Done():
+		return c.error()
 	}
 }
 
 func (c *Client) Close() error {
 	c.closeOnce.Do(func() {
-		close(c.closed)
+		c.closed()
 	})
-
 	return nil
 }
 
@@ -159,51 +181,82 @@ type message struct {
 	err error
 }
 
-func (c *Client) run() {
-	var (
-		streamID    uint32 = 1
-		waiters            = make(map[uint32]*callRequest)
-		calls              = c.calls
-		incoming           = make(chan *message)
-		shutdown           = make(chan struct{})
-		shutdownErr error
-	)
+type receiver struct {
+	wg       *sync.WaitGroup
+	messages chan *message
+	err      error
+}
 
-	go func() {
-		defer close(shutdown)
+func (r *receiver) run(ctx context.Context, c *channel) {
+	defer r.wg.Done()
 
-		// start one more goroutine to recv messages without blocking.
-		for {
-			mh, p, err := c.channel.recv(context.TODO())
+	for {
+		select {
+		case <-ctx.Done():
+			r.err = ctx.Err()
+			return
+		default:
+			mh, p, err := c.recv()
 			if err != nil {
 				_, ok := status.FromError(err)
 				if !ok {
 					// treat all errors that are not an rpc status as terminal.
 					// all others poison the connection.
-					shutdownErr = err
+					r.err = filterCloseErr(err)
 					return
 				}
 			}
 			select {
-			case incoming <- &message{
+			case r.messages <- &message{
 				messageHeader: mh,
 				p:             p[:mh.Length],
 				err:           err,
 			}:
-			case <-c.done:
+			case <-ctx.Done():
+				r.err = ctx.Err()
 				return
 			}
 		}
+	}
+}
+
+func (c *Client) run() {
+	var (
+		streamID      uint32 = 1
+		waiters              = make(map[uint32]*callRequest)
+		calls                = c.calls
+		incoming             = make(chan *message)
+		receiversDone        = make(chan struct{})
+		wg            sync.WaitGroup
+	)
+
+	// broadcast the shutdown error to the remaining waiters.
+	abortWaiters := func(wErr error) {
+		for _, waiter := range waiters {
+			waiter.errs <- wErr
+		}
+	}
+	recv := &receiver{
+		wg:       &wg,
+		messages: incoming,
+	}
+	wg.Add(1)
+
+	go func() {
+		wg.Wait()
+		close(receiversDone)
 	}()
+	go recv.run(c.ctx, c.channel)
 
-	defer c.conn.Close()
-	defer close(c.done)
-	defer c.closeFunc()
+	defer func() {
+		c.conn.Close()
+		c.userCloseFunc()
+	}()
 
 	for {
 		select {
 		case call := <-calls:
-			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
+			if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
 				call.errs <- err
 				continue
 			}
@@ -219,41 +272,42 @@ func (c *Client) run() {
 
 			call.errs <- c.recv(call.resp, msg)
 			delete(waiters, msg.StreamID)
-		case <-shutdown:
-			if shutdownErr != nil {
-				shutdownErr = filterCloseErr(shutdownErr)
-			} else {
-				shutdownErr = ErrClosed
-			}
-
-			shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
-
-			c.err = shutdownErr
-			for _, waiter := range waiters {
-				waiter.errs <- shutdownErr
+		case <-receiversDone:
+			// all the receivers have exited
+			if recv.err != nil {
+				c.setError(recv.err)
 			}
+			// don't return out, let the close of the context trigger the abort of waiters
 			c.Close()
-			return
-		case <-c.closed:
-			if c.err == nil {
-				c.err = ErrClosed
-			}
-			// broadcast the shutdown error to the remaining waiters.
-			for _, waiter := range waiters {
-				waiter.errs <- c.err
-			}
+		case <-c.ctx.Done():
+			abortWaiters(c.error())
 			return
 		}
 	}
 }
 
-func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
+func (c *Client) error() error {
+	c.errOnce.Do(func() {
+		if c.err == nil {
+			c.err = ErrClosed
+		}
+	})
+	return c.err
+}
+
+func (c *Client) setError(err error) {
+	c.errOnce.Do(func() {
+		c.err = err
+	})
+}
+
+func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
 	p, err := c.codec.Marshal(msg)
 	if err != nil {
 		return err
 	}
 
-	return c.channel.send(ctx, streamID, mtype, p)
+	return c.channel.send(streamID, mtype, p)
 }
 
 func (c *Client) recv(resp *Response, msg *message) error {
@@ -274,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
 //
 // This purposely ignores errors with a wrapped cause.
 func filterCloseErr(err error) error {
-	if err == nil {
+	switch {
+	case err == nil:
 		return nil
-	}
-
-	if err == io.EOF {
+	case err == io.EOF:
 		return ErrClosed
-	}
-
-	if strings.Contains(err.Error(), "use of closed network connection") {
+	case errors.Cause(err) == io.EOF:
 		return ErrClosed
-	}
-
-	// if we have an epipe on a write, we cast to errclosed
-	if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
-		if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
-			return ErrClosed
+	case strings.Contains(err.Error(), "use of closed network connection"):
+		return ErrClosed
+	default:
+		// if we have an epipe on a write, we cast to errclosed
+		if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
+			if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
+				return ErrClosed
+			}
 		}
 	}
 

+ 14 - 1
vendor/github.com/containerd/ttrpc/config.go

@@ -19,9 +19,11 @@ package ttrpc
 import "github.com/pkg/errors"
 
 type serverConfig struct {
-	handshaker Handshaker
+	handshaker  Handshaker
+	interceptor UnaryServerInterceptor
 }
 
+// ServerOpt for configuring a ttrpc server
 type ServerOpt func(*serverConfig) error
 
 // WithServerHandshaker can be passed to NewServer to ensure that the
@@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
 		return nil
 	}
 }
+
+// WithUnaryServerInterceptor sets the provided interceptor on the server
+func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
+	return func(c *serverConfig) error {
+		if c.interceptor != nil {
+			return errors.New("only one interceptor allowed per server")
+		}
+		c.interceptor = i
+		return nil
+	}
+}

+ 50 - 0
vendor/github.com/containerd/ttrpc/interceptor.go

@@ -0,0 +1,50 @@
+/*
+   Copyright The containerd Authors.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ttrpc
+
+import "context"
+
+// UnaryServerInfo provides information about the server request
+type UnaryServerInfo struct {
+	FullMethod string
+}
+
+// UnaryClientInfo provides information about the client request
+type UnaryClientInfo struct {
+	FullMethod string
+}
+
+// Unmarshaler contains the server request data and allows it to be unmarshaled
+// into a concrete type
+type Unmarshaler func(interface{}) error
+
+// Invoker invokes the client's request and response from the ttrpc server
+type Invoker func(context.Context, *Request, *Response) error
+
+// UnaryServerInterceptor specifies the interceptor function for server request/response
+type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
+
+// UnaryClientInterceptor specifies the interceptor function for client request/response
+type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error
+
+func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
+	return method(ctx, unmarshal)
+}
+
+func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
+	return invoker(ctx, req, resp)
+}

+ 107 - 0
vendor/github.com/containerd/ttrpc/metadata.go

@@ -0,0 +1,107 @@
+/*
+   Copyright The containerd Authors.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ttrpc
+
+import (
+	"context"
+	"strings"
+)
+
+// MD is the user type for ttrpc metadata
+type MD map[string][]string
+
+// Get returns the metadata for a given key when they exist.
+// If there is no metadata, a nil slice and false are returned.
+func (m MD) Get(key string) ([]string, bool) {
+	key = strings.ToLower(key)
+	list, ok := m[key]
+	if !ok || len(list) == 0 {
+		return nil, false
+	}
+
+	return list, true
+}
+
+// Set sets the provided values for a given key.
+// The values will overwrite any existing values.
+// If no values provided, a key will be deleted.
+func (m MD) Set(key string, values ...string) {
+	key = strings.ToLower(key)
+	if len(values) == 0 {
+		delete(m, key)
+		return
+	}
+	m[key] = values
+}
+
+// Append appends additional values to the given key.
+func (m MD) Append(key string, values ...string) {
+	key = strings.ToLower(key)
+	if len(values) == 0 {
+		return
+	}
+	current, ok := m[key]
+	if ok {
+		m.Set(key, append(current, values...)...)
+	} else {
+		m.Set(key, values...)
+	}
+}
+
+func (m MD) setRequest(r *Request) {
+	for k, values := range m {
+		for _, v := range values {
+			r.Metadata = append(r.Metadata, &KeyValue{
+				Key:   k,
+				Value: v,
+			})
+		}
+	}
+}
+
+func (m MD) fromRequest(r *Request) {
+	for _, kv := range r.Metadata {
+		m[kv.Key] = append(m[kv.Key], kv.Value)
+	}
+}
+
+type metadataKey struct{}
+
+// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
+func GetMetadata(ctx context.Context) (MD, bool) {
+	metadata, ok := ctx.Value(metadataKey{}).(MD)
+	return metadata, ok
+}
+
+// GetMetadataValue gets a specific metadata value by name from context.Context
+func GetMetadataValue(ctx context.Context, name string) (string, bool) {
+	metadata, ok := GetMetadata(ctx)
+	if !ok {
+		return "", false
+	}
+
+	if list, ok := metadata.Get(name); ok {
+		return list[0], true
+	}
+
+	return "", false
+}
+
+// WithMetadata attaches metadata map to a context.Context
+func WithMetadata(ctx context.Context, md MD) context.Context {
+	return context.WithValue(ctx, metadataKey{}, md)
+}

+ 18 - 4
vendor/github.com/containerd/ttrpc/server.go

@@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
 			return nil, err
 		}
 	}
+	if config.interceptor == nil {
+		config.interceptor = defaultServerInterceptor
+	}
 
 	return &Server{
 		config:      config,
-		services:    newServiceSet(),
+		services:    newServiceSet(config.interceptor),
 		done:        make(chan struct{}),
 		listeners:   make(map[net.Listener]struct{}),
 		connections: make(map[*serverConn]struct{}),
@@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
 			default: // proceed
 			}
 
-			mh, p, err := ch.recv(ctx)
+			mh, p, err := ch.recv()
 			if err != nil {
 				status, ok := status.FromError(err)
 				if !ok {
@@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
 				return
 			}
 
-			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
+			if err := ch.send(response.id, messageTypeResponse, p); err != nil {
 				logrus.WithError(err).Error("failed sending message on channel")
 				return
 			}
@@ -449,7 +452,12 @@ func (c *serverConn) run(sctx context.Context) {
 			// branch. Basically, it means that we are no longer receiving
 			// requests due to a terminal error.
 			recvErr = nil // connection is now "closing"
-			if err != nil && err != io.EOF {
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
+				// The client went away and we should stop processing
+				// requests, so that the client connection is closed
+				return
+			}
+			if err != nil {
 				logrus.WithError(err).Error("error receiving message")
 			}
 		case <-shutdown:
@@ -461,6 +469,12 @@ func (c *serverConn) run(sctx context.Context) {
 var noopFunc = func() {}
 
 func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
+	if len(req.Metadata) > 0 {
+		md := MD{}
+		md.fromRequest(req)
+		ctx = WithMetadata(ctx, md)
+	}
+
 	cancel = noopFunc
 	if req.TimeoutNano == 0 {
 		return ctx, cancel

+ 11 - 5
vendor/github.com/containerd/ttrpc/services.go

@@ -37,12 +37,14 @@ type ServiceDesc struct {
 }
 
 type serviceSet struct {
-	services map[string]ServiceDesc
+	services    map[string]ServiceDesc
+	interceptor UnaryServerInterceptor
 }
 
-func newServiceSet() *serviceSet {
+func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
 	return &serviceSet{
-		services: make(map[string]ServiceDesc),
+		services:    make(map[string]ServiceDesc),
+		interceptor: interceptor,
 	}
 }
 
@@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin
 		return nil
 	}
 
-	resp, err := method(ctx, unmarshal)
+	info := &UnaryServerInfo{
+		FullMethod: fullPath(serviceName, methodName),
+	}
+
+	resp, err := s.interceptor(ctx, unmarshal, info, method)
 	if err != nil {
 		return nil, err
 	}
@@ -146,5 +152,5 @@ func convertCode(err error) codes.Code {
 }
 
 func fullPath(service, method string) string {
-	return "/" + path.Join("/", service, method)
+	return "/" + path.Join(service, method)
 }

+ 24 - 4
vendor/github.com/containerd/ttrpc/types.go

@@ -23,10 +23,11 @@ import (
 )
 
 type Request struct {
-	Service     string `protobuf:"bytes,1,opt,name=service,proto3"`
-	Method      string `protobuf:"bytes,2,opt,name=method,proto3"`
-	Payload     []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
-	TimeoutNano int64  `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
+	Service     string      `protobuf:"bytes,1,opt,name=service,proto3"`
+	Method      string      `protobuf:"bytes,2,opt,name=method,proto3"`
+	Payload     []byte      `protobuf:"bytes,3,opt,name=payload,proto3"`
+	TimeoutNano int64       `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
+	Metadata    []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"`
 }
 
 func (r *Request) Reset()         { *r = Request{} }
@@ -41,3 +42,22 @@ type Response struct {
 func (r *Response) Reset()         { *r = Response{} }
 func (r *Response) String() string { return fmt.Sprintf("%+#v", r) }
 func (r *Response) ProtoMessage()  {}
+
+type StringList struct {
+	List []string `protobuf:"bytes,1,rep,name=list,proto3"`
+}
+
+func (r *StringList) Reset()         { *r = StringList{} }
+func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
+func (r *StringList) ProtoMessage()  {}
+
+func makeStringList(item ...string) StringList { return StringList{List: item} }
+
+type KeyValue struct {
+	Key   string `protobuf:"bytes,1,opt,name=key,proto3"`
+	Value string `protobuf:"bytes,2,opt,name=value,proto3"`
+}
+
+func (m *KeyValue) Reset()         { *m = KeyValue{} }
+func (*KeyValue) ProtoMessage()    {}
+func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }