|
- /*
- Copyright The containerd Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package ttrpc
- import (
- "context"
- "errors"
- "io"
- "math/rand"
- "net"
- "sync"
- "sync/atomic"
- "syscall"
- "time"
- "github.com/sirupsen/logrus"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
- )
- type Server struct {
- config *serverConfig
- services *serviceSet
- codec codec
- mu sync.Mutex
- listeners map[net.Listener]struct{}
- connections map[*serverConn]struct{} // all connections to current state
- done chan struct{} // marks point at which we stop serving requests
- }
- func NewServer(opts ...ServerOpt) (*Server, error) {
- config := &serverConfig{}
- for _, opt := range opts {
- if err := opt(config); err != nil {
- return nil, err
- }
- }
- if config.interceptor == nil {
- config.interceptor = defaultServerInterceptor
- }
- return &Server{
- config: config,
- services: newServiceSet(config.interceptor),
- done: make(chan struct{}),
- listeners: make(map[net.Listener]struct{}),
- connections: make(map[*serverConn]struct{}),
- }, nil
- }
- // Register registers a map of methods to method handlers
- // TODO: Remove in 2.0, does not support streams
- func (s *Server) Register(name string, methods map[string]Method) {
- s.services.register(name, &ServiceDesc{Methods: methods})
- }
- func (s *Server) RegisterService(name string, desc *ServiceDesc) {
- s.services.register(name, desc)
- }
- func (s *Server) Serve(ctx context.Context, l net.Listener) error {
- s.addListener(l)
- defer s.closeListener(l)
- var (
- backoff time.Duration
- handshaker = s.config.handshaker
- )
- if handshaker == nil {
- handshaker = handshakerFunc(noopHandshake)
- }
- for {
- conn, err := l.Accept()
- if err != nil {
- select {
- case <-s.done:
- return ErrServerClosed
- default:
- }
- if terr, ok := err.(interface {
- Temporary() bool
- }); ok && terr.Temporary() {
- if backoff == 0 {
- backoff = time.Millisecond
- } else {
- backoff *= 2
- }
- if max := time.Second; backoff > max {
- backoff = max
- }
- sleep := time.Duration(rand.Int63n(int64(backoff)))
- logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
- time.Sleep(sleep)
- continue
- }
- return err
- }
- backoff = 0
- approved, handshake, err := handshaker.Handshake(ctx, conn)
- if err != nil {
- logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
- conn.Close()
- continue
- }
- sc, err := s.newConn(approved, handshake)
- if err != nil {
- logrus.WithError(err).Error("ttrpc: create connection failed")
- conn.Close()
- continue
- }
- go sc.run(ctx)
- }
- }
- func (s *Server) Shutdown(ctx context.Context) error {
- s.mu.Lock()
- select {
- case <-s.done:
- default:
- // protected by mutex
- close(s.done)
- }
- lnerr := s.closeListeners()
- s.mu.Unlock()
- ticker := time.NewTicker(200 * time.Millisecond)
- defer ticker.Stop()
- for {
- s.closeIdleConns()
- if s.countConnection() == 0 {
- break
- }
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-ticker.C:
- }
- }
- return lnerr
- }
- // Close the server without waiting for active connections.
- func (s *Server) Close() error {
- s.mu.Lock()
- defer s.mu.Unlock()
- select {
- case <-s.done:
- default:
- // protected by mutex
- close(s.done)
- }
- err := s.closeListeners()
- for c := range s.connections {
- c.close()
- delete(s.connections, c)
- }
- return err
- }
- func (s *Server) addListener(l net.Listener) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.listeners[l] = struct{}{}
- }
- func (s *Server) closeListener(l net.Listener) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.closeListenerLocked(l)
- }
- func (s *Server) closeListenerLocked(l net.Listener) error {
- defer delete(s.listeners, l)
- return l.Close()
- }
- func (s *Server) closeListeners() error {
- var err error
- for l := range s.listeners {
- if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
- err = cerr
- }
- }
- return err
- }
- func (s *Server) addConnection(c *serverConn) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- select {
- case <-s.done:
- return ErrServerClosed
- default:
- }
- s.connections[c] = struct{}{}
- return nil
- }
- func (s *Server) delConnection(c *serverConn) {
- s.mu.Lock()
- defer s.mu.Unlock()
- delete(s.connections, c)
- }
- func (s *Server) countConnection() int {
- s.mu.Lock()
- defer s.mu.Unlock()
- return len(s.connections)
- }
- func (s *Server) closeIdleConns() {
- s.mu.Lock()
- defer s.mu.Unlock()
- for c := range s.connections {
- if st, ok := c.getState(); !ok || st == connStateActive {
- continue
- }
- c.close()
- delete(s.connections, c)
- }
- }
// connState describes the lifecycle state of a server connection.
type connState int

// Connection states. Values start at 1, so the zero value of connState is
// not a named state.
const (
	connStateActive = iota + 1 // outstanding requests
	connStateIdle              // no requests
	connStateClosed            // closed connection
)

// String returns the lower-case name of the state, or "unknown" for any
// value that is not one of the declared states.
func (cs connState) String() string {
	// Index 0 is intentionally left empty: no state has that value.
	names := [...]string{
		connStateActive: "active",
		connStateIdle:   "idle",
		connStateClosed: "closed",
	}

	if cs >= connStateActive && int(cs) < len(names) {
		if name := names[cs]; name != "" {
			return name
		}
	}
	return "unknown"
}
- func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
- c := &serverConn{
- server: s,
- conn: conn,
- handshake: handshake,
- shutdown: make(chan struct{}),
- }
- c.setState(connStateIdle)
- if err := s.addConnection(c); err != nil {
- c.close()
- return nil, err
- }
- return c, nil
- }
- type serverConn struct {
- server *Server
- conn net.Conn
- handshake interface{} // data from handshake, not used for now
- state atomic.Value
- shutdownOnce sync.Once
- shutdown chan struct{} // forced shutdown, used by close
- }
- func (c *serverConn) getState() (connState, bool) {
- cs, ok := c.state.Load().(connState)
- return cs, ok
- }
- func (c *serverConn) setState(newstate connState) {
- c.state.Store(newstate)
- }
- func (c *serverConn) close() error {
- c.shutdownOnce.Do(func() {
- close(c.shutdown)
- })
- return nil
- }
- func (c *serverConn) run(sctx context.Context) {
- type (
- response struct {
- id uint32
- status *status.Status
- data []byte
- closeStream bool
- streaming bool
- }
- )
- var (
- ch = newChannel(c.conn)
- ctx, cancel = context.WithCancel(sctx)
- state connState = connStateIdle
- responses = make(chan response)
- recvErr = make(chan error, 1)
- done = make(chan struct{})
- streams = sync.Map{}
- active int32
- lastStreamID uint32
- )
- defer c.conn.Close()
- defer cancel()
- defer close(done)
- defer c.server.delConnection(c)
- sendStatus := func(id uint32, st *status.Status) bool {
- select {
- case responses <- response{
- // even though we've had an invalid stream id, we send it
- // back on the same stream id so the client knows which
- // stream id was bad.
- id: id,
- status: st,
- closeStream: true,
- }:
- return true
- case <-c.shutdown:
- return false
- case <-done:
- return false
- }
- }
- go func(recvErr chan error) {
- defer close(recvErr)
- for {
- select {
- case <-c.shutdown:
- return
- case <-done:
- return
- default: // proceed
- }
- mh, p, err := ch.recv()
- if err != nil {
- status, ok := status.FromError(err)
- if !ok {
- recvErr <- err
- return
- }
- // in this case, we send an error for that particular message
- // when the status is defined.
- if !sendStatus(mh.StreamID, status) {
- return
- }
- continue
- }
- if mh.StreamID%2 != 1 {
- // enforce odd client initiated identifiers.
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
- return
- }
- continue
- }
- if mh.Type == messageTypeData {
- i, ok := streams.Load(mh.StreamID)
- if !ok {
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
- return
- }
- }
- sh := i.(*streamHandler)
- if mh.Flags&flagNoData != flagNoData {
- unmarshal := func(obj interface{}) error {
- err := protoUnmarshal(p, obj)
- ch.putmbuf(p)
- return err
- }
- if err := sh.data(unmarshal); err != nil {
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
- return
- }
- }
- }
- if mh.Flags&flagRemoteClosed == flagRemoteClosed {
- sh.closeSend()
- if len(p) > 0 {
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
- return
- }
- }
- }
- } else if mh.Type == messageTypeRequest {
- if mh.StreamID <= lastStreamID {
- // enforce odd client initiated identifiers.
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
- return
- }
- continue
- }
- lastStreamID = mh.StreamID
- // TODO: Make request type configurable
- // Unmarshaller which takes in a byte array and returns an interface?
- var req Request
- if err := c.server.codec.Unmarshal(p, &req); err != nil {
- ch.putmbuf(p)
- if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
- return
- }
- continue
- }
- ch.putmbuf(p)
- id := mh.StreamID
- respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
- select {
- case responses <- response{
- id: id,
- status: status,
- data: data,
- closeStream: closeStream,
- streaming: streaming,
- }:
- case <-done:
- return ErrClosed
- }
- return nil
- }
- sh, err := c.server.services.handle(ctx, &req, respond)
- if err != nil {
- status, _ := status.FromError(err)
- if !sendStatus(mh.StreamID, status) {
- return
- }
- continue
- }
- streams.Store(id, sh)
- atomic.AddInt32(&active, 1)
- }
- // TODO: else we must ignore this for future compat. log this?
- }
- }(recvErr)
- for {
- var (
- newstate connState
- shutdown chan struct{}
- )
- activeN := atomic.LoadInt32(&active)
- if activeN > 0 {
- newstate = connStateActive
- shutdown = nil
- } else {
- newstate = connStateIdle
- shutdown = c.shutdown // only enable this branch in idle mode
- }
- if newstate != state {
- c.setState(newstate)
- state = newstate
- }
- select {
- case response := <-responses:
- if !response.streaming || response.status.Code() != codes.OK {
- p, err := c.server.codec.Marshal(&Response{
- Status: response.status.Proto(),
- Payload: response.data,
- })
- if err != nil {
- logrus.WithError(err).Error("failed marshaling response")
- return
- }
- if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
- logrus.WithError(err).Error("failed sending message on channel")
- return
- }
- } else {
- var flags uint8
- if response.closeStream {
- flags = flagRemoteClosed
- }
- if response.data == nil {
- flags = flags | flagNoData
- }
- if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
- logrus.WithError(err).Error("failed sending message on channel")
- return
- }
- }
- if response.closeStream {
- // The ttrpc protocol currently does not support the case where
- // the server is localClosed but not remoteClosed. Once the server
- // is closing, the whole stream may be considered finished
- streams.Delete(response.id)
- atomic.AddInt32(&active, -1)
- }
- case err := <-recvErr:
- // TODO(stevvooe): Not wildly clear what we should do in this
- // branch. Basically, it means that we are no longer receiving
- // requests due to a terminal error.
- recvErr = nil // connection is now "closing"
- if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) {
- // The client went away and we should stop processing
- // requests, so that the client connection is closed
- return
- }
- logrus.WithError(err).Error("error receiving message")
- // else, initiate shutdown
- case <-shutdown:
- return
- }
- }
- }
- var noopFunc = func() {}
- func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
- if len(req.Metadata) > 0 {
- md := MD{}
- md.fromRequest(req)
- ctx = WithMetadata(ctx, md)
- }
- cancel = noopFunc
- if req.TimeoutNano == 0 {
- return ctx, cancel
- }
- ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
- return ctx, cancel
- }
|