server.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. /*
  2. Copyright The containerd Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ttrpc
  14. import (
  15. "context"
  16. "errors"
  17. "io"
  18. "math/rand"
  19. "net"
  20. "sync"
  21. "sync/atomic"
  22. "syscall"
  23. "time"
  24. "github.com/sirupsen/logrus"
  25. "google.golang.org/grpc/codes"
  26. "google.golang.org/grpc/status"
  27. )
  28. type Server struct {
  29. config *serverConfig
  30. services *serviceSet
  31. codec codec
  32. mu sync.Mutex
  33. listeners map[net.Listener]struct{}
  34. connections map[*serverConn]struct{} // all connections to current state
  35. done chan struct{} // marks point at which we stop serving requests
  36. }
  37. func NewServer(opts ...ServerOpt) (*Server, error) {
  38. config := &serverConfig{}
  39. for _, opt := range opts {
  40. if err := opt(config); err != nil {
  41. return nil, err
  42. }
  43. }
  44. if config.interceptor == nil {
  45. config.interceptor = defaultServerInterceptor
  46. }
  47. return &Server{
  48. config: config,
  49. services: newServiceSet(config.interceptor),
  50. done: make(chan struct{}),
  51. listeners: make(map[net.Listener]struct{}),
  52. connections: make(map[*serverConn]struct{}),
  53. }, nil
  54. }
  55. // Register registers a map of methods to method handlers
  56. // TODO: Remove in 2.0, does not support streams
  57. func (s *Server) Register(name string, methods map[string]Method) {
  58. s.services.register(name, &ServiceDesc{Methods: methods})
  59. }
  60. func (s *Server) RegisterService(name string, desc *ServiceDesc) {
  61. s.services.register(name, desc)
  62. }
  63. func (s *Server) Serve(ctx context.Context, l net.Listener) error {
  64. s.addListener(l)
  65. defer s.closeListener(l)
  66. var (
  67. backoff time.Duration
  68. handshaker = s.config.handshaker
  69. )
  70. if handshaker == nil {
  71. handshaker = handshakerFunc(noopHandshake)
  72. }
  73. for {
  74. conn, err := l.Accept()
  75. if err != nil {
  76. select {
  77. case <-s.done:
  78. return ErrServerClosed
  79. default:
  80. }
  81. if terr, ok := err.(interface {
  82. Temporary() bool
  83. }); ok && terr.Temporary() {
  84. if backoff == 0 {
  85. backoff = time.Millisecond
  86. } else {
  87. backoff *= 2
  88. }
  89. if max := time.Second; backoff > max {
  90. backoff = max
  91. }
  92. sleep := time.Duration(rand.Int63n(int64(backoff)))
  93. logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
  94. time.Sleep(sleep)
  95. continue
  96. }
  97. return err
  98. }
  99. backoff = 0
  100. approved, handshake, err := handshaker.Handshake(ctx, conn)
  101. if err != nil {
  102. logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
  103. conn.Close()
  104. continue
  105. }
  106. sc, err := s.newConn(approved, handshake)
  107. if err != nil {
  108. logrus.WithError(err).Error("ttrpc: create connection failed")
  109. conn.Close()
  110. continue
  111. }
  112. go sc.run(ctx)
  113. }
  114. }
  115. func (s *Server) Shutdown(ctx context.Context) error {
  116. s.mu.Lock()
  117. select {
  118. case <-s.done:
  119. default:
  120. // protected by mutex
  121. close(s.done)
  122. }
  123. lnerr := s.closeListeners()
  124. s.mu.Unlock()
  125. ticker := time.NewTicker(200 * time.Millisecond)
  126. defer ticker.Stop()
  127. for {
  128. s.closeIdleConns()
  129. if s.countConnection() == 0 {
  130. break
  131. }
  132. select {
  133. case <-ctx.Done():
  134. return ctx.Err()
  135. case <-ticker.C:
  136. }
  137. }
  138. return lnerr
  139. }
  140. // Close the server without waiting for active connections.
  141. func (s *Server) Close() error {
  142. s.mu.Lock()
  143. defer s.mu.Unlock()
  144. select {
  145. case <-s.done:
  146. default:
  147. // protected by mutex
  148. close(s.done)
  149. }
  150. err := s.closeListeners()
  151. for c := range s.connections {
  152. c.close()
  153. delete(s.connections, c)
  154. }
  155. return err
  156. }
  157. func (s *Server) addListener(l net.Listener) {
  158. s.mu.Lock()
  159. defer s.mu.Unlock()
  160. s.listeners[l] = struct{}{}
  161. }
  162. func (s *Server) closeListener(l net.Listener) error {
  163. s.mu.Lock()
  164. defer s.mu.Unlock()
  165. return s.closeListenerLocked(l)
  166. }
  167. func (s *Server) closeListenerLocked(l net.Listener) error {
  168. defer delete(s.listeners, l)
  169. return l.Close()
  170. }
  171. func (s *Server) closeListeners() error {
  172. var err error
  173. for l := range s.listeners {
  174. if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
  175. err = cerr
  176. }
  177. }
  178. return err
  179. }
  180. func (s *Server) addConnection(c *serverConn) error {
  181. s.mu.Lock()
  182. defer s.mu.Unlock()
  183. select {
  184. case <-s.done:
  185. return ErrServerClosed
  186. default:
  187. }
  188. s.connections[c] = struct{}{}
  189. return nil
  190. }
  191. func (s *Server) delConnection(c *serverConn) {
  192. s.mu.Lock()
  193. defer s.mu.Unlock()
  194. delete(s.connections, c)
  195. }
  196. func (s *Server) countConnection() int {
  197. s.mu.Lock()
  198. defer s.mu.Unlock()
  199. return len(s.connections)
  200. }
  201. func (s *Server) closeIdleConns() {
  202. s.mu.Lock()
  203. defer s.mu.Unlock()
  204. for c := range s.connections {
  205. if st, ok := c.getState(); !ok || st == connStateActive {
  206. continue
  207. }
  208. c.close()
  209. delete(s.connections, c)
  210. }
  211. }
  212. type connState int
  213. const (
  214. connStateActive = iota + 1 // outstanding requests
  215. connStateIdle // no requests
  216. connStateClosed // closed connection
  217. )
  218. func (cs connState) String() string {
  219. switch cs {
  220. case connStateActive:
  221. return "active"
  222. case connStateIdle:
  223. return "idle"
  224. case connStateClosed:
  225. return "closed"
  226. default:
  227. return "unknown"
  228. }
  229. }
  230. func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
  231. c := &serverConn{
  232. server: s,
  233. conn: conn,
  234. handshake: handshake,
  235. shutdown: make(chan struct{}),
  236. }
  237. c.setState(connStateIdle)
  238. if err := s.addConnection(c); err != nil {
  239. c.close()
  240. return nil, err
  241. }
  242. return c, nil
  243. }
  244. type serverConn struct {
  245. server *Server
  246. conn net.Conn
  247. handshake interface{} // data from handshake, not used for now
  248. state atomic.Value
  249. shutdownOnce sync.Once
  250. shutdown chan struct{} // forced shutdown, used by close
  251. }
  252. func (c *serverConn) getState() (connState, bool) {
  253. cs, ok := c.state.Load().(connState)
  254. return cs, ok
  255. }
  256. func (c *serverConn) setState(newstate connState) {
  257. c.state.Store(newstate)
  258. }
  259. func (c *serverConn) close() error {
  260. c.shutdownOnce.Do(func() {
  261. close(c.shutdown)
  262. })
  263. return nil
  264. }
  265. func (c *serverConn) run(sctx context.Context) {
  266. type (
  267. response struct {
  268. id uint32
  269. status *status.Status
  270. data []byte
  271. closeStream bool
  272. streaming bool
  273. }
  274. )
  275. var (
  276. ch = newChannel(c.conn)
  277. ctx, cancel = context.WithCancel(sctx)
  278. state connState = connStateIdle
  279. responses = make(chan response)
  280. recvErr = make(chan error, 1)
  281. done = make(chan struct{})
  282. streams = sync.Map{}
  283. active int32
  284. lastStreamID uint32
  285. )
  286. defer c.conn.Close()
  287. defer cancel()
  288. defer close(done)
  289. defer c.server.delConnection(c)
  290. sendStatus := func(id uint32, st *status.Status) bool {
  291. select {
  292. case responses <- response{
  293. // even though we've had an invalid stream id, we send it
  294. // back on the same stream id so the client knows which
  295. // stream id was bad.
  296. id: id,
  297. status: st,
  298. closeStream: true,
  299. }:
  300. return true
  301. case <-c.shutdown:
  302. return false
  303. case <-done:
  304. return false
  305. }
  306. }
  307. go func(recvErr chan error) {
  308. defer close(recvErr)
  309. for {
  310. select {
  311. case <-c.shutdown:
  312. return
  313. case <-done:
  314. return
  315. default: // proceed
  316. }
  317. mh, p, err := ch.recv()
  318. if err != nil {
  319. status, ok := status.FromError(err)
  320. if !ok {
  321. recvErr <- err
  322. return
  323. }
  324. // in this case, we send an error for that particular message
  325. // when the status is defined.
  326. if !sendStatus(mh.StreamID, status) {
  327. return
  328. }
  329. continue
  330. }
  331. if mh.StreamID%2 != 1 {
  332. // enforce odd client initiated identifiers.
  333. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
  334. return
  335. }
  336. continue
  337. }
  338. if mh.Type == messageTypeData {
  339. i, ok := streams.Load(mh.StreamID)
  340. if !ok {
  341. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
  342. return
  343. }
  344. }
  345. sh := i.(*streamHandler)
  346. if mh.Flags&flagNoData != flagNoData {
  347. unmarshal := func(obj interface{}) error {
  348. err := protoUnmarshal(p, obj)
  349. ch.putmbuf(p)
  350. return err
  351. }
  352. if err := sh.data(unmarshal); err != nil {
  353. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
  354. return
  355. }
  356. }
  357. }
  358. if mh.Flags&flagRemoteClosed == flagRemoteClosed {
  359. sh.closeSend()
  360. if len(p) > 0 {
  361. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
  362. return
  363. }
  364. }
  365. }
  366. } else if mh.Type == messageTypeRequest {
  367. if mh.StreamID <= lastStreamID {
  368. // enforce odd client initiated identifiers.
  369. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
  370. return
  371. }
  372. continue
  373. }
  374. lastStreamID = mh.StreamID
  375. // TODO: Make request type configurable
  376. // Unmarshaller which takes in a byte array and returns an interface?
  377. var req Request
  378. if err := c.server.codec.Unmarshal(p, &req); err != nil {
  379. ch.putmbuf(p)
  380. if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
  381. return
  382. }
  383. continue
  384. }
  385. ch.putmbuf(p)
  386. id := mh.StreamID
  387. respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
  388. select {
  389. case responses <- response{
  390. id: id,
  391. status: status,
  392. data: data,
  393. closeStream: closeStream,
  394. streaming: streaming,
  395. }:
  396. case <-done:
  397. return ErrClosed
  398. }
  399. return nil
  400. }
  401. sh, err := c.server.services.handle(ctx, &req, respond)
  402. if err != nil {
  403. status, _ := status.FromError(err)
  404. if !sendStatus(mh.StreamID, status) {
  405. return
  406. }
  407. continue
  408. }
  409. streams.Store(id, sh)
  410. atomic.AddInt32(&active, 1)
  411. }
  412. // TODO: else we must ignore this for future compat. log this?
  413. }
  414. }(recvErr)
  415. for {
  416. var (
  417. newstate connState
  418. shutdown chan struct{}
  419. )
  420. activeN := atomic.LoadInt32(&active)
  421. if activeN > 0 {
  422. newstate = connStateActive
  423. shutdown = nil
  424. } else {
  425. newstate = connStateIdle
  426. shutdown = c.shutdown // only enable this branch in idle mode
  427. }
  428. if newstate != state {
  429. c.setState(newstate)
  430. state = newstate
  431. }
  432. select {
  433. case response := <-responses:
  434. if !response.streaming || response.status.Code() != codes.OK {
  435. p, err := c.server.codec.Marshal(&Response{
  436. Status: response.status.Proto(),
  437. Payload: response.data,
  438. })
  439. if err != nil {
  440. logrus.WithError(err).Error("failed marshaling response")
  441. return
  442. }
  443. if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
  444. logrus.WithError(err).Error("failed sending message on channel")
  445. return
  446. }
  447. } else {
  448. var flags uint8
  449. if response.closeStream {
  450. flags = flagRemoteClosed
  451. }
  452. if response.data == nil {
  453. flags = flags | flagNoData
  454. }
  455. if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
  456. logrus.WithError(err).Error("failed sending message on channel")
  457. return
  458. }
  459. }
  460. if response.closeStream {
  461. // The ttrpc protocol currently does not support the case where
  462. // the server is localClosed but not remoteClosed. Once the server
  463. // is closing, the whole stream may be considered finished
  464. streams.Delete(response.id)
  465. atomic.AddInt32(&active, -1)
  466. }
  467. case err := <-recvErr:
  468. // TODO(stevvooe): Not wildly clear what we should do in this
  469. // branch. Basically, it means that we are no longer receiving
  470. // requests due to a terminal error.
  471. recvErr = nil // connection is now "closing"
  472. if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) {
  473. // The client went away and we should stop processing
  474. // requests, so that the client connection is closed
  475. return
  476. }
  477. logrus.WithError(err).Error("error receiving message")
  478. // else, initiate shutdown
  479. case <-shutdown:
  480. return
  481. }
  482. }
  483. }
  484. var noopFunc = func() {}
  485. func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
  486. if len(req.Metadata) > 0 {
  487. md := MD{}
  488. md.fromRequest(req)
  489. ctx = WithMetadata(ctx, md)
  490. }
  491. cancel = noopFunc
  492. if req.TimeoutNano == 0 {
  493. return ctx, cancel
  494. }
  495. ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
  496. return ctx, cancel
  497. }