123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- /*
- Copyright The containerd Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package ttrpc
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "net"
- "strings"
- "sync"
- "syscall"
- "time"
- "github.com/sirupsen/logrus"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
- "google.golang.org/protobuf/proto"
- )
- // Client for a ttrpc server
- type Client struct {
- codec codec
- conn net.Conn
- channel *channel
- streamLock sync.RWMutex
- streams map[streamID]*stream
- nextStreamID streamID
- sendLock sync.Mutex
- ctx context.Context
- closed func()
- closeOnce sync.Once
- userCloseFunc func()
- userCloseWaitCh chan struct{}
- interceptor UnaryClientInterceptor
- }
- // ClientOpts configures a client
- type ClientOpts func(c *Client)
- // WithOnClose sets the close func whenever the client's Close() method is called
- func WithOnClose(onClose func()) ClientOpts {
- return func(c *Client) {
- c.userCloseFunc = onClose
- }
- }
- // WithUnaryClientInterceptor sets the provided client interceptor
- func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
- return func(c *Client) {
- c.interceptor = i
- }
- }
- // NewClient creates a new ttrpc client using the given connection
- func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
- ctx, cancel := context.WithCancel(context.Background())
- channel := newChannel(conn)
- c := &Client{
- codec: codec{},
- conn: conn,
- channel: channel,
- streams: make(map[streamID]*stream),
- nextStreamID: 1,
- closed: cancel,
- ctx: ctx,
- userCloseFunc: func() {},
- userCloseWaitCh: make(chan struct{}),
- interceptor: defaultClientInterceptor,
- }
- for _, o := range opts {
- o(c)
- }
- go c.run()
- return c
- }
- func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
- c.sendLock.Lock()
- defer c.sendLock.Unlock()
- return c.channel.send(sid, mt, flags, b)
- }
- // Call makes a unary request and returns with response
- func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
- payload, err := c.codec.Marshal(req)
- if err != nil {
- return err
- }
- var (
- creq = &Request{
- Service: service,
- Method: method,
- Payload: payload,
- // TODO: metadata from context
- }
- cresp = &Response{}
- )
- if metadata, ok := GetMetadata(ctx); ok {
- metadata.setRequest(creq)
- }
- if dl, ok := ctx.Deadline(); ok {
- creq.TimeoutNano = time.Until(dl).Nanoseconds()
- }
- info := &UnaryClientInfo{
- FullMethod: fullPath(service, method),
- }
- if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
- return err
- }
- if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
- return err
- }
- if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
- return status.ErrorProto(cresp.Status)
- }
- return nil
- }
// StreamDesc describes which sides of a stream are streaming: the client,
// the server, or both.
type StreamDesc struct {
	// StreamingClient reports whether the client may send multiple messages.
	StreamingClient bool
	// StreamingServer reports whether the server may send multiple messages.
	StreamingServer bool
}
// ClientStream is the client's handle on a stream. It is used to send and
// receive messages on the underlying stream.
type ClientStream interface {
	// CloseSend closes the sending side of the stream.
	CloseSend() error
	// SendMsg sends m on the stream.
	SendMsg(m interface{}) error
	// RecvMsg receives the next message from the stream into m.
	RecvMsg(m interface{}) error
}
- type clientStream struct {
- ctx context.Context
- s *stream
- c *Client
- desc *StreamDesc
- localClosed bool
- remoteClosed bool
- }
- func (cs *clientStream) CloseSend() error {
- if !cs.desc.StreamingClient {
- return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
- }
- if cs.localClosed {
- return ErrStreamClosed
- }
- err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
- if err != nil {
- return filterCloseErr(err)
- }
- cs.localClosed = true
- return nil
- }
- func (cs *clientStream) SendMsg(m interface{}) error {
- if !cs.desc.StreamingClient {
- return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
- }
- if cs.localClosed {
- return ErrStreamClosed
- }
- var (
- payload []byte
- err error
- )
- if m != nil {
- payload, err = cs.c.codec.Marshal(m)
- if err != nil {
- return err
- }
- }
- err = cs.s.send(messageTypeData, 0, payload)
- if err != nil {
- return filterCloseErr(err)
- }
- return nil
- }
- func (cs *clientStream) RecvMsg(m interface{}) error {
- if cs.remoteClosed {
- return io.EOF
- }
- var msg *streamMessage
- select {
- case <-cs.ctx.Done():
- return cs.ctx.Err()
- case <-cs.s.recvClose:
- // If recv has a pending message, process that first
- select {
- case msg = <-cs.s.recv:
- default:
- return cs.s.recvErr
- }
- case msg = <-cs.s.recv:
- }
- if msg.header.Type == messageTypeResponse {
- resp := &Response{}
- err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
- // return the payload buffer for reuse
- cs.c.channel.putmbuf(msg.payload)
- if err != nil {
- return err
- }
- if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
- return err
- }
- if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
- return status.ErrorProto(resp.Status)
- }
- cs.c.deleteStream(cs.s)
- cs.remoteClosed = true
- return nil
- } else if msg.header.Type == messageTypeData {
- if !cs.desc.StreamingServer {
- cs.c.deleteStream(cs.s)
- cs.remoteClosed = true
- return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
- }
- if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
- cs.c.deleteStream(cs.s)
- cs.remoteClosed = true
- if msg.header.Flags&flagNoData == flagNoData {
- return io.EOF
- }
- }
- err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
- cs.c.channel.putmbuf(msg.payload)
- if err != nil {
- return err
- }
- return nil
- }
- return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
- }
- // Close closes the ttrpc connection and underlying connection
- func (c *Client) Close() error {
- c.closeOnce.Do(func() {
- c.closed()
- c.conn.Close()
- })
- return nil
- }
- // UserOnCloseWait is used to blocks untils the user's on-close callback
- // finishes.
- func (c *Client) UserOnCloseWait(ctx context.Context) error {
- select {
- case <-c.userCloseWaitCh:
- return nil
- case <-ctx.Done():
- return ctx.Err()
- }
- }
- func (c *Client) run() {
- err := c.receiveLoop()
- c.Close()
- c.cleanupStreams(err)
- c.userCloseFunc()
- close(c.userCloseWaitCh)
- }
- func (c *Client) receiveLoop() error {
- for {
- select {
- case <-c.ctx.Done():
- return ErrClosed
- default:
- var (
- msg = &streamMessage{}
- err error
- )
- msg.header, msg.payload, err = c.channel.recv()
- if err != nil {
- _, ok := status.FromError(err)
- if !ok {
- // treat all errors that are not an rpc status as terminal.
- // all others poison the connection.
- return filterCloseErr(err)
- }
- }
- sid := streamID(msg.header.StreamID)
- s := c.getStream(sid)
- if s == nil {
- logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
- continue
- }
- if err != nil {
- s.closeWithError(err)
- } else {
- if err := s.receive(c.ctx, msg); err != nil {
- logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
- }
- }
- }
- }
- }
- // createStream creates a new stream and registers it with the client
- // Introduce stream types for multiple or single response
- func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
- c.streamLock.Lock()
- // Check if closed since lock acquired to prevent adding
- // anything after cleanup completes
- select {
- case <-c.ctx.Done():
- c.streamLock.Unlock()
- return nil, ErrClosed
- default:
- }
- // Stream ID should be allocated at same time
- s := newStream(c.nextStreamID, c)
- c.streams[s.id] = s
- c.nextStreamID = c.nextStreamID + 2
- c.sendLock.Lock()
- defer c.sendLock.Unlock()
- c.streamLock.Unlock()
- if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
- return s, filterCloseErr(err)
- }
- return s, nil
- }
- func (c *Client) deleteStream(s *stream) {
- c.streamLock.Lock()
- delete(c.streams, s.id)
- c.streamLock.Unlock()
- s.closeWithError(nil)
- }
- func (c *Client) getStream(sid streamID) *stream {
- c.streamLock.RLock()
- s := c.streams[sid]
- c.streamLock.RUnlock()
- return s
- }
- func (c *Client) cleanupStreams(err error) {
- c.streamLock.Lock()
- defer c.streamLock.Unlock()
- for sid, s := range c.streams {
- s.closeWithError(err)
- delete(c.streams, sid)
- }
- }
- // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
- // returning from call or handling errors from main read loop.
- //
- // This purposely ignores errors with a wrapped cause.
- func filterCloseErr(err error) error {
- switch {
- case err == nil:
- return nil
- case err == io.EOF:
- return ErrClosed
- case errors.Is(err, io.ErrClosedPipe):
- return ErrClosed
- case errors.Is(err, io.EOF):
- return ErrClosed
- case strings.Contains(err.Error(), "use of closed network connection"):
- return ErrClosed
- default:
- // if we have an epipe on a write or econnreset on a read , we cast to errclosed
- var oerr *net.OpError
- if errors.As(err, &oerr) {
- if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
- (oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
- return ErrClosed
- }
- }
- }
- return err
- }
- // NewStream creates a new stream with the given stream descriptor to the
- // specified service and method. If not a streaming client, the request object
- // may be provided.
- func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
- var payload []byte
- if req != nil {
- var err error
- payload, err = c.codec.Marshal(req)
- if err != nil {
- return nil, err
- }
- }
- request := &Request{
- Service: service,
- Method: method,
- Payload: payload,
- // TODO: metadata from context
- }
- p, err := c.codec.Marshal(request)
- if err != nil {
- return nil, err
- }
- var flags uint8
- if desc.StreamingClient {
- flags = flagRemoteOpen
- } else {
- flags = flagRemoteClosed
- }
- s, err := c.createStream(flags, p)
- if err != nil {
- return nil, err
- }
- return &clientStream{
- ctx: ctx,
- s: s,
- c: c,
- desc: desc,
- }, nil
- }
- func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
- p, err := c.codec.Marshal(req)
- if err != nil {
- return err
- }
- s, err := c.createStream(0, p)
- if err != nil {
- return err
- }
- defer c.deleteStream(s)
- var msg *streamMessage
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c.ctx.Done():
- return ErrClosed
- case <-s.recvClose:
- // If recv has a pending message, process that first
- select {
- case msg = <-s.recv:
- default:
- return s.recvErr
- }
- case msg = <-s.recv:
- }
- if msg.header.Type == messageTypeResponse {
- err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
- } else {
- err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
- }
- // return the payload buffer for reuse
- c.channel.putmbuf(msg.payload)
- return err
- }
|