package ttrpc import ( "context" "math/rand" "net" "sync" "sync/atomic" "time" "github.com/containerd/containerd/log" "github.com/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( ErrServerClosed = errors.New("ttrpc: server close") ) type Server struct { services *serviceSet codec codec mu sync.Mutex listeners map[net.Listener]struct{} connections map[*serverConn]struct{} // all connections to current state done chan struct{} // marks point at which we stop serving requests } func NewServer() *Server { return &Server{ services: newServiceSet(), done: make(chan struct{}), listeners: make(map[net.Listener]struct{}), connections: make(map[*serverConn]struct{}), } } func (s *Server) Register(name string, methods map[string]Method) { s.services.register(name, methods) } func (s *Server) Serve(l net.Listener) error { s.addListener(l) defer s.closeListener(l) var ( ctx = context.Background() backoff time.Duration ) for { conn, err := l.Accept() if err != nil { select { case <-s.done: return ErrServerClosed default: } if terr, ok := err.(interface { Temporary() bool }); ok && terr.Temporary() { if backoff == 0 { backoff = time.Millisecond } else { backoff *= 2 } if max := time.Second; backoff > max { backoff = max } sleep := time.Duration(rand.Int63n(int64(backoff))) log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep) time.Sleep(sleep) continue } return err } backoff = 0 sc := s.newConn(conn) go sc.run(ctx) } } func (s *Server) Shutdown(ctx context.Context) error { s.mu.Lock() lnerr := s.closeListeners() select { case <-s.done: default: // protected by mutex close(s.done) } s.mu.Unlock() ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for { if s.closeIdleConns() { return lnerr } select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } } // Close the server without waiting for active connections. func (s *Server) Close() error { s.mu.Lock() defer s.mu.Unlock() select { case <-s.done: default: // protected by mutex close(s.done) } err := s.closeListeners() for c := range s.connections { c.close() delete(s.connections, c) } return err } func (s *Server) addListener(l net.Listener) { s.mu.Lock() defer s.mu.Unlock() s.listeners[l] = struct{}{} } func (s *Server) closeListener(l net.Listener) error { s.mu.Lock() defer s.mu.Unlock() return s.closeListenerLocked(l) } func (s *Server) closeListenerLocked(l net.Listener) error { defer delete(s.listeners, l) return l.Close() } func (s *Server) closeListeners() error { var err error for l := range s.listeners { if cerr := s.closeListenerLocked(l); cerr != nil && err == nil { err = cerr } } return err } func (s *Server) addConnection(c *serverConn) { s.mu.Lock() defer s.mu.Unlock() s.connections[c] = struct{}{} } func (s *Server) closeIdleConns() bool { s.mu.Lock() defer s.mu.Unlock() quiescent := true for c := range s.connections { st, ok := c.getState() if !ok || st != connStateIdle { quiescent = false continue } c.close() delete(s.connections, c) } return quiescent } type connState int const ( connStateActive = iota + 1 // outstanding requests connStateIdle // no requests connStateClosed // closed connection ) func (cs connState) String() string { switch cs { case connStateActive: return "active" case connStateIdle: return "idle" case connStateClosed: return "closed" default: return "unknown" } } func (s *Server) newConn(conn net.Conn) *serverConn { c := &serverConn{ server: s, conn: conn, shutdown: make(chan struct{}), } c.setState(connStateIdle) s.addConnection(c) return c } type serverConn struct { server *Server conn net.Conn state atomic.Value shutdownOnce sync.Once shutdown chan struct{} // forced shutdown, used by close } func (c *serverConn) getState() (connState, bool) { cs, ok := c.state.Load().(connState) return cs, ok } func (c *serverConn) setState(newstate connState) { c.state.Store(newstate) } func (c *serverConn) close() error { c.shutdownOnce.Do(func() { close(c.shutdown) }) return nil } func (c *serverConn) run(sctx context.Context) { type ( request struct { id uint32 req *Request } response struct { id uint32 resp *Response } ) var ( ch = newChannel(c.conn, c.conn) ctx, cancel = context.WithCancel(sctx) active int state connState = connStateIdle responses = make(chan response) requests = make(chan request) recvErr = make(chan error, 1) shutdown = c.shutdown done = make(chan struct{}) ) defer c.conn.Close() defer cancel() defer close(done) go func(recvErr chan error) { defer close(recvErr) sendImmediate := func(id uint32, st *status.Status) bool { select { case responses <- response{ // even though we've had an invalid stream id, we send it // back on the same stream id so the client knows which // stream id was bad. id: id, resp: &Response{ Status: st.Proto(), }, }: return true case <-c.shutdown: return false case <-done: return false } } for { select { case <-c.shutdown: return case <-done: return default: // proceed } mh, p, err := ch.recv(ctx) if err != nil { status, ok := status.FromError(err) if !ok { recvErr <- err return } // in this case, we send an error for that particular message // when the status is defined. if !sendImmediate(mh.StreamID, status) { return } continue } if mh.Type != messageTypeRequest { // we must ignore this for future compat. continue } var req Request if err := c.server.codec.Unmarshal(p, &req); err != nil { ch.putmbuf(p) if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { return } continue } ch.putmbuf(p) if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { return } continue } // Forward the request to the main loop. We don't wait on s.done // because we have already accepted the client request. select { case requests <- request{ id: mh.StreamID, req: &req, }: case <-done: return } } }(recvErr) for { newstate := state switch { case active > 0: newstate = connStateActive shutdown = nil case active == 0: newstate = connStateIdle shutdown = c.shutdown // only enable this branch in idle mode } if newstate != state { c.setState(newstate) state = newstate } select { case request := <-requests: active++ go func(id uint32) { p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) resp := &Response{ Status: status.Proto(), Payload: p, } select { case responses <- response{ id: id, resp: resp, }: case <-done: } }(request.id) case response := <-responses: p, err := c.server.codec.Marshal(response.resp) if err != nil { log.L.WithError(err).Error("failed marshaling response") return } if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { log.L.WithError(err).Error("failed sending message on channel") return } active-- case err := <-recvErr: // TODO(stevvooe): Not wildly clear what we should do in this // branch. Basically, it means that we are no longer receiving // requests due to a terminal error. recvErr = nil // connection is now "closing" if err != nil { log.L.WithError(err).Error("error receiving message") } case <-shutdown: return } } }