vendor/github.com/stevvooe/ttrpc/server.go
5bd902b5
 package ttrpc
 
 import (
 	"context"
c2cb302d
 	"math/rand"
5bd902b5
 	"net"
c2cb302d
 	"sync"
 	"sync/atomic"
 	"time"
5bd902b5
 
 	"github.com/containerd/containerd/log"
c2cb302d
 	"github.com/pkg/errors"
5bd902b5
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
c2cb302d
 // ErrServerClosed is returned by Serve once the server has been stopped
 // with Shutdown or Close.
 var (
 	ErrServerClosed = errors.New("ttrpc: server closed")
 )
 
5bd902b5
 type Server struct {
 	services *serviceSet
 	codec    codec
c2cb302d
 
 	mu          sync.Mutex
 	listeners   map[net.Listener]struct{}
 	connections map[*serverConn]struct{} // all connections to current state
 	done        chan struct{}            // marks point at which we stop serving requests
5bd902b5
 }
 
 // NewServer returns a Server with a fresh service set and empty listener and
 // connection tables, ready for Register and Serve.
 func NewServer() *Server {
 	srv := &Server{services: newServiceSet()}
 	srv.done = make(chan struct{})
 	srv.listeners = make(map[net.Listener]struct{})
 	srv.connections = make(map[*serverConn]struct{})
 	return srv
 }
 
 // Register adds the given methods under the service name. Incoming requests
 // are routed to them by their service and method names.
 func (s *Server) Register(name string, methods map[string]Method) {
 	set := s.services
 	set.register(name, methods)
 }
 
 // Serve accepts connections on l and handles each one on its own goroutine.
 // The listener is tracked by the server while Serve runs and is closed when
 // Serve returns.
 //
 // Accept errors that report Temporary() are retried after a random sleep in
 // [0, backoff), where backoff starts at 1ms and doubles on each consecutive
 // failure up to 1s. Once the server is shut down or closed, Serve returns
 // ErrServerClosed; any other accept error is returned unchanged.
 func (s *Server) Serve(l net.Listener) error {
 	s.addListener(l)
 	defer s.closeListener(l)
 
 	ctx := context.Background()
 	var backoff time.Duration
 
 	for {
 		conn, err := l.Accept()
 		if err == nil {
 			backoff = 0
 			sc := s.newConn(conn)
 			go sc.run(ctx)
 			continue
 		}
 
 		// An accept failure after shutdown is expected; report it uniformly.
 		select {
 		case <-s.done:
 			return ErrServerClosed
 		default:
 		}
 
 		temp, ok := err.(interface {
 			Temporary() bool
 		})
 		if !ok || !temp.Temporary() {
 			return err
 		}
 
 		switch {
 		case backoff == 0:
 			backoff = time.Millisecond
 		case backoff < time.Second:
 			backoff *= 2
 		}
 		if backoff > time.Second {
 			backoff = time.Second
 		}
 
 		sleep := time.Duration(rand.Int63n(int64(backoff)))
 		log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
 		time.Sleep(sleep)
 	}
 }
 
 // Shutdown stops the server gracefully. It marks the server done, closes
 // every listener, and then polls every 200ms, closing each connection that
 // is idle, until no connections remain.
 //
 // It returns ctx.Err() if ctx is done first; otherwise it returns the first
 // error encountered while closing the listeners.
 func (s *Server) Shutdown(ctx context.Context) error {
 	s.mu.Lock()
 	// Close done BEFORE the listeners. Serve checks done (without the lock)
 	// when Accept fails. If the listener were closed first, Serve could observe
 	// the failure while done is still open and return the raw listener error
 	// instead of ErrServerClosed.
 	select {
 	case <-s.done:
 	default:
 		// protected by mutex
 		close(s.done)
 	}
 	lnerr := s.closeListeners()
 	s.mu.Unlock()
 
 	ticker := time.NewTicker(200 * time.Millisecond)
 	defer ticker.Stop()
 	for {
 		if s.closeIdleConns() {
 			return lnerr
 		}
 		select {
 		case <-ctx.Done():
 			return ctx.Err()
 		case <-ticker.C:
 		}
 	}
 }
 
 // Close shuts the server down without waiting for active connections. It
 // marks the server done, closes every listener, and signals every tracked
 // connection to shut down before forgetting it. It returns the first error
 // from closing the listeners.
 //
 // NOTE: a connection's serve loop only acts on that signal while it has no
 // outstanding requests, so in-flight requests still run to completion.
 func (s *Server) Close() error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	select {
 	case <-s.done:
 		// Already stopped; done must be closed exactly once.
 	default:
 		// Safe: done is only ever closed while s.mu is held.
 		close(s.done)
 	}
 
 	firstErr := s.closeListeners()
 	for conn := range s.connections {
 		conn.close()
 		delete(s.connections, conn)
 	}
 	return firstErr
 }
 
 // addListener starts tracking l so that Shutdown and Close can close it.
 func (s *Server) addListener(l net.Listener) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	tracked := s.listeners
 	tracked[l] = struct{}{}
 }
 
 // closeListener closes l and stops tracking it, taking s.mu itself. Callers
 // that already hold the lock use closeListenerLocked instead.
 func (s *Server) closeListener(l net.Listener) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	return s.closeListenerLocked(l)
 }
 
 // closeListenerLocked closes l, removes it from the tracked set and returns
 // the error from l.Close. The caller must hold s.mu.
 func (s *Server) closeListenerLocked(l net.Listener) error {
 	closeErr := l.Close()
 	delete(s.listeners, l)
 	return closeErr
 }
 
 // closeListeners closes every tracked listener and returns the first error
 // encountered, if any. The caller must hold s.mu.
 func (s *Server) closeListeners() error {
 	var first error
 	for l := range s.listeners {
 		err := s.closeListenerLocked(l)
 		if first == nil {
 			first = err
 		}
 	}
 	return first
 }
 
 // addConnection starts tracking c so that Shutdown and Close can reach it.
 func (s *Server) addConnection(c *serverConn) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	tracked := s.connections
 	tracked[c] = struct{}{}
 }
 
 func (s *Server) closeIdleConns() bool {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	quiescent := true
 	for c := range s.connections {
 		st, ok := c.getState()
 		if !ok || st != connStateIdle {
 			quiescent = false
5bd902b5
 			continue
 		}
c2cb302d
 		c.close()
 		delete(s.connections, c)
 	}
 	return quiescent
 }
 
 // connState describes what a server connection is doing. It is published
 // via serverConn.setState and read by Server.closeIdleConns to decide which
 // connections may be closed.
 type connState int
 
 // The constants are typed so that formatting them goes through
 // connState.String rather than printing a bare integer.
 const (
 	connStateActive connState = iota + 1 // outstanding requests
 	connStateIdle                        // no requests
 	connStateClosed                      // closed connection
 )
 
 // String returns a readable name for cs, or "unknown" for values outside
 // the defined set.
 func (cs connState) String() string {
 	switch cs {
 	case connStateActive:
 		return "active"
 	case connStateIdle:
 		return "idle"
 	case connStateClosed:
 		return "closed"
 	default:
 		return "unknown"
 	}
 }
5bd902b5
 
c2cb302d
 // newConn wraps conn in a serverConn owned by s, marks it idle and registers
 // it with the server. Serve then drives it by calling run.
 func (s *Server) newConn(conn net.Conn) *serverConn {
 	sc := new(serverConn)
 	sc.server = s
 	sc.conn = conn
 	sc.shutdown = make(chan struct{})
 	sc.setState(connStateIdle)
 	s.addConnection(sc)
 	return sc
 }
 
 // serverConn is the server side of a single accepted connection.
 type serverConn struct {
 	server *Server      // owning server
 	conn   net.Conn     // underlying transport
 	state  atomic.Value // holds a connState; see getState and setState
 
 	// shutdown is closed exactly once, through shutdownOnce, by close to
 	// force the connection down.
 	shutdownOnce sync.Once
 	shutdown     chan struct{}
 }
 
 // getState returns the state most recently stored by setState. ok is false
 // when nothing has been stored yet.
 func (c *serverConn) getState() (st connState, ok bool) {
 	st, ok = c.state.Load().(connState)
 	return st, ok
 }
 
 func (c *serverConn) setState(newstate connState) {
 	c.state.Store(newstate)
5bd902b5
 }
 
c2cb302d
 func (c *serverConn) close() error {
 	c.shutdownOnce.Do(func() {
 		close(c.shutdown)
 	})
5bd902b5
 
c2cb302d
 	return nil
 }
 
 // run serves the connection until it is shut down or can make no further
 // progress.
 //
 // A receive goroutine reads messages, answers malformed ones with an
 // immediate error response, and forwards valid requests. Each request is
 // handled on its own goroutine. This goroutine owns the write side: it sends
 // every response, counts outstanding requests in active and publishes the
 // connection state.
 //
 // The forced-shutdown signal (c.shutdown) is only honoured while no requests
 // are outstanding, so in-flight requests run to completion.
 func (c *serverConn) run(sctx context.Context) {
 	type (
 		request struct {
 			id  uint32
 			req *Request
 		}
 
 		response struct {
 			id   uint32
 			resp *Response
 			// dispatched is true when this response completes a request that
 			// was counted in active. Immediate error responses are not counted.
 			dispatched bool
 		}
 	)
 
 	var (
 		ch          = newChannel(c.conn, c.conn)
 		ctx, cancel = context.WithCancel(sctx)
 		active      int
 		state       connState = connStateIdle
 		responses             = make(chan response)
 		requests              = make(chan request)
 		recvErr               = make(chan error, 1)
 		recvDone    bool      // receive goroutine exited; no more requests will arrive
 		shutdown              = c.shutdown
 		done                  = make(chan struct{})
 	)
 
 	// Deferred calls run in reverse order: close(done) releases helper
 	// goroutines, cancel stops handlers, the socket is closed, and finally the
 	// connection is dropped from the server so it is neither leaked in the
 	// table nor waited on by Shutdown.
 	defer func() {
 		c.server.mu.Lock()
 		delete(c.server.connections, c)
 		c.server.mu.Unlock()
 		c.setState(connStateClosed)
 	}()
 	defer c.conn.Close()
 	defer cancel()
 	defer close(done)
 
 	go func(recvErr chan error) {
 		// Closing recvErr on every exit path lets the main loop notice that
 		// no more requests will arrive.
 		defer close(recvErr)
 		sendImmediate := func(id uint32, st *status.Status) bool {
 			select {
 			case responses <- response{
 				// even though we've had an invalid stream id, we send it
 				// back on the same stream id so the client knows which
 				// stream id was bad.
 				id: id,
 				resp: &Response{
 					Status: st.Proto(),
 				},
 			}:
 				return true
 			case <-c.shutdown:
 				return false
 			case <-done:
 				return false
 			}
 		}
 
 		for {
 			select {
 			case <-c.shutdown:
 				return
 			case <-done:
 				return
 			default: // proceed
 			}
 
 			mh, p, err := ch.recv(ctx)
 			if err != nil {
 				st, ok := status.FromError(err)
 				if !ok {
 					recvErr <- err
 					return
 				}
 
 				// in this case, we send an error for that particular message
 				// when the status is defined.
 				if !sendImmediate(mh.StreamID, st) {
 					return
 				}
 
 				continue
 			}
 
 			if mh.Type != messageTypeRequest {
 				// we must ignore this for future compat.
 				continue
 			}
 
 			var req Request
 			if err := c.server.codec.Unmarshal(p, &req); err != nil {
 				ch.putmbuf(p)
 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
 					return
 				}
 				continue
 			}
 			ch.putmbuf(p)
 
 			if mh.StreamID%2 != 1 {
 				// enforce odd client initiated identifiers.
 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
 					return
 				}
 				continue
 			}
 
 			// Forward the request to the main loop. We don't wait on s.done
 			// because we have already accepted the client request.
 			select {
 			case requests <- request{
 				id:  mh.StreamID,
 				req: &req,
 			}:
 			case <-done:
 				return
 			}
 		}
 	}(recvErr)
 
 	for {
 		// Publish the state. The shutdown branch is armed (non-nil) only
 		// while idle, because a nil channel is never ready in a select.
 		var newstate connState
 		if active > 0 {
 			newstate, shutdown = connStateActive, nil
 		} else {
 			newstate, shutdown = connStateIdle, c.shutdown
 		}
 
 		if newstate != state {
 			c.setState(newstate)
 			state = newstate
 		}
 
 		// Without a receiver no new requests can arrive. Once nothing is
 		// outstanding, the connection is finished; exiting here closes the
 		// socket instead of parking until the server shuts down.
 		if recvDone && active == 0 {
 			return
 		}
 
 		select {
 		case r := <-requests:
 			active++
 			go func(id uint32, req *Request) {
 				p, st := c.server.services.call(ctx, req.Service, req.Method, req.Payload)
 				resp := &Response{
 					Status:  st.Proto(),
 					Payload: p,
 				}
 
 				select {
 				case responses <- response{
 					id:         id,
 					resp:       resp,
 					dispatched: true,
 				}:
 				case <-done:
 				}
 			}(r.id, r.req)
 		case response := <-responses:
 			p, err := c.server.codec.Marshal(response.resp)
 			if err != nil {
 				log.L.WithError(err).Error("failed marshaling response")
 				return
 			}
 
 			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
 				log.L.WithError(err).Error("failed sending message on channel")
 				return
 			}
 
 			// Only responses to dispatched requests were counted; decrementing
 			// for immediate error responses would drive active negative.
 			if response.dispatched {
 				active--
 			}
 		case err := <-recvErr:
 			// The receive goroutine has stopped, so we are no longer receiving
 			// requests. Finish outstanding ones, then exit at the top of the loop.
 			recvErr = nil // connection is now "closing"
 			recvDone = true
 			if err != nil {
 				log.L.WithError(err).Error("error receiving message")
 			}
 		case <-shutdown:
 			return
 		}
 	}
 }