Browse code

Merge pull request #340 from thaJeztah/19.03_backport_bump_grpc

[19.03 backport] bump google.golang.org/grpc v1.23.0 (CVE-2019-9512, CVE-2019-9514, CVE-2019-9515)

Andrew Hsu authored on 2019/09/24 01:32:43
Showing 44 changed files
... ...
@@ -74,7 +74,7 @@ github.com/opencontainers/go-digest                 279bed98673dd5bef374d3b6e4b0
74 74
 # get go-zfs packages
75 75
 github.com/mistifyio/go-zfs                         f784269be439d704d3dfa1906f45dd848fed2beb
76 76
 
77
-google.golang.org/grpc                              25c4f928eaa6d96443009bd842389fb4fa48664e # v1.20.1
77
+google.golang.org/grpc                              6eaf6f47437a6b4e2153a190160ef39a92c7eceb # v1.23.0
78 78
 
79 79
 # The version of runc should match the version that is used by the containerd
80 80
 # version that is used. If you need to update runc, open a pull request in
... ...
@@ -127,7 +127,7 @@ github.com/containerd/cgroups                       4994991857f9b0ae8dc439551e8b
127 127
 github.com/containerd/console                       0650fd9eeb50bab4fc99dceb9f2e14cf58f36e7f
128 128
 github.com/containerd/go-runc                       7d11b49dc0769f6dbb0d1b19f3d48524d1bad9ad
129 129
 github.com/containerd/typeurl                       2a93cfde8c20b23de8eb84a5adbc234ddf7a9e8d
130
-github.com/containerd/ttrpc                         699c4e40d1e7416e08bf7019c7ce2e9beced4636
130
+github.com/containerd/ttrpc                         92c8520ef9f86600c650dd540266a007bf03670f
131 131
 github.com/gogo/googleapis                          d31c731455cb061f42baff3bda55bad0118b126b # v1.2.0
132 132
 
133 133
 # cluster
... ...
@@ -18,7 +18,6 @@ package ttrpc
18 18
 
19 19
 import (
20 20
 	"bufio"
21
-	"context"
22 21
 	"encoding/binary"
23 22
 	"io"
24 23
 	"net"
... ...
@@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
98 98
 // returned will be valid and caller should send that along to
99 99
 // the correct consumer. The bytes on the underlying channel
100 100
 // will be discarded.
101
-func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
101
+func (ch *channel) recv() (messageHeader, []byte, error) {
102 102
 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
103 103
 	if err != nil {
104 104
 		return messageHeader{}, nil, err
... ...
@@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
120 120
 	return mh, p, nil
121 121
 }
122 122
 
123
-func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
123
+func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
124 124
 	if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
125 125
 		return err
126 126
 	}
... ...
@@ -29,6 +29,7 @@ import (
29 29
 	"github.com/gogo/protobuf/proto"
30 30
 	"github.com/pkg/errors"
31 31
 	"github.com/sirupsen/logrus"
32
+	"google.golang.org/grpc/codes"
32 33
 	"google.golang.org/grpc/status"
33 34
 )
34 35
 
... ...
@@ -36,36 +37,52 @@ import (
36 36
 // closed.
37 37
 var ErrClosed = errors.New("ttrpc: closed")
38 38
 
39
+// Client for a ttrpc server
39 40
 type Client struct {
40 41
 	codec   codec
41 42
 	conn    net.Conn
42 43
 	channel *channel
43 44
 	calls   chan *callRequest
44 45
 
45
-	closed    chan struct{}
46
-	closeOnce sync.Once
47
-	closeFunc func()
48
-	done      chan struct{}
49
-	err       error
46
+	ctx    context.Context
47
+	closed func()
48
+
49
+	closeOnce     sync.Once
50
+	userCloseFunc func()
51
+
52
+	errOnce     sync.Once
53
+	err         error
54
+	interceptor UnaryClientInterceptor
50 55
 }
51 56
 
57
+// ClientOpts configures a client
52 58
 type ClientOpts func(c *Client)
53 59
 
60
+// WithOnClose sets the close func whenever the client's Close() method is called
54 61
 func WithOnClose(onClose func()) ClientOpts {
55 62
 	return func(c *Client) {
56
-		c.closeFunc = onClose
63
+		c.userCloseFunc = onClose
64
+	}
65
+}
66
+
67
+// WithUnaryClientInterceptor sets the provided client interceptor
68
+func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
69
+	return func(c *Client) {
70
+		c.interceptor = i
57 71
 	}
58 72
 }
59 73
 
60 74
 func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
75
+	ctx, cancel := context.WithCancel(context.Background())
61 76
 	c := &Client{
62
-		codec:     codec{},
63
-		conn:      conn,
64
-		channel:   newChannel(conn),
65
-		calls:     make(chan *callRequest),
66
-		closed:    make(chan struct{}),
67
-		done:      make(chan struct{}),
68
-		closeFunc: func() {},
77
+		codec:         codec{},
78
+		conn:          conn,
79
+		channel:       newChannel(conn),
80
+		calls:         make(chan *callRequest),
81
+		closed:        cancel,
82
+		ctx:           ctx,
83
+		userCloseFunc: func() {},
84
+		interceptor:   defaultClientInterceptor,
69 85
 	}
70 86
 
71 87
 	for _, o := range opts {
... ...
@@ -99,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
99 99
 		cresp = &Response{}
100 100
 	)
101 101
 
102
+	if metadata, ok := GetMetadata(ctx); ok {
103
+		metadata.setRequest(creq)
104
+	}
105
+
102 106
 	if dl, ok := ctx.Deadline(); ok {
103 107
 		creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
104 108
 	}
105 109
 
106
-	if err := c.dispatch(ctx, creq, cresp); err != nil {
110
+	info := &UnaryClientInfo{
111
+		FullMethod: fullPath(service, method),
112
+	}
113
+	if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
107 114
 		return err
108 115
 	}
109 116
 
... ...
@@ -111,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
111 111
 		return err
112 112
 	}
113 113
 
114
-	if cresp.Status == nil {
115
-		return errors.New("no status provided on response")
114
+	if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
115
+		return status.ErrorProto(cresp.Status)
116 116
 	}
117
-
118
-	return status.ErrorProto(cresp.Status)
117
+	return nil
119 118
 }
120 119
 
121 120
 func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
... ...
@@ -131,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
131 131
 	case <-ctx.Done():
132 132
 		return ctx.Err()
133 133
 	case c.calls <- call:
134
-	case <-c.done:
135
-		return c.err
134
+	case <-c.ctx.Done():
135
+		return c.error()
136 136
 	}
137 137
 
138 138
 	select {
... ...
@@ -140,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
140 140
 		return ctx.Err()
141 141
 	case err := <-errs:
142 142
 		return filterCloseErr(err)
143
-	case <-c.done:
144
-		return c.err
143
+	case <-c.ctx.Done():
144
+		return c.error()
145 145
 	}
146 146
 }
147 147
 
148 148
 func (c *Client) Close() error {
149 149
 	c.closeOnce.Do(func() {
150
-		close(c.closed)
150
+		c.closed()
151 151
 	})
152
-
153 152
 	return nil
154 153
 }
155 154
 
... ...
@@ -159,51 +181,82 @@ type message struct {
159 159
 	err error
160 160
 }
161 161
 
162
-func (c *Client) run() {
163
-	var (
164
-		streamID    uint32 = 1
165
-		waiters            = make(map[uint32]*callRequest)
166
-		calls              = c.calls
167
-		incoming           = make(chan *message)
168
-		shutdown           = make(chan struct{})
169
-		shutdownErr error
170
-	)
162
+type receiver struct {
163
+	wg       *sync.WaitGroup
164
+	messages chan *message
165
+	err      error
166
+}
171 167
 
172
-	go func() {
173
-		defer close(shutdown)
168
+func (r *receiver) run(ctx context.Context, c *channel) {
169
+	defer r.wg.Done()
174 170
 
175
-		// start one more goroutine to recv messages without blocking.
176
-		for {
177
-			mh, p, err := c.channel.recv(context.TODO())
171
+	for {
172
+		select {
173
+		case <-ctx.Done():
174
+			r.err = ctx.Err()
175
+			return
176
+		default:
177
+			mh, p, err := c.recv()
178 178
 			if err != nil {
179 179
 				_, ok := status.FromError(err)
180 180
 				if !ok {
181 181
 					// treat all errors that are not an rpc status as terminal.
182 182
 					// all others poison the connection.
183
-					shutdownErr = err
183
+					r.err = filterCloseErr(err)
184 184
 					return
185 185
 				}
186 186
 			}
187 187
 			select {
188
-			case incoming <- &message{
188
+			case r.messages <- &message{
189 189
 				messageHeader: mh,
190 190
 				p:             p[:mh.Length],
191 191
 				err:           err,
192 192
 			}:
193
-			case <-c.done:
193
+			case <-ctx.Done():
194
+				r.err = ctx.Err()
194 195
 				return
195 196
 			}
196 197
 		}
198
+	}
199
+}
200
+
201
+func (c *Client) run() {
202
+	var (
203
+		streamID      uint32 = 1
204
+		waiters              = make(map[uint32]*callRequest)
205
+		calls                = c.calls
206
+		incoming             = make(chan *message)
207
+		receiversDone        = make(chan struct{})
208
+		wg            sync.WaitGroup
209
+	)
210
+
211
+	// broadcast the shutdown error to the remaining waiters.
212
+	abortWaiters := func(wErr error) {
213
+		for _, waiter := range waiters {
214
+			waiter.errs <- wErr
215
+		}
216
+	}
217
+	recv := &receiver{
218
+		wg:       &wg,
219
+		messages: incoming,
220
+	}
221
+	wg.Add(1)
222
+
223
+	go func() {
224
+		wg.Wait()
225
+		close(receiversDone)
197 226
 	}()
227
+	go recv.run(c.ctx, c.channel)
198 228
 
199
-	defer c.conn.Close()
200
-	defer close(c.done)
201
-	defer c.closeFunc()
229
+	defer func() {
230
+		c.conn.Close()
231
+		c.userCloseFunc()
232
+	}()
202 233
 
203 234
 	for {
204 235
 		select {
205 236
 		case call := <-calls:
206
-			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
237
+			if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
207 238
 				call.errs <- err
208 239
 				continue
209 240
 			}
... ...
@@ -219,41 +272,42 @@ func (c *Client) run() {
219 219
 
220 220
 			call.errs <- c.recv(call.resp, msg)
221 221
 			delete(waiters, msg.StreamID)
222
-		case <-shutdown:
223
-			if shutdownErr != nil {
224
-				shutdownErr = filterCloseErr(shutdownErr)
225
-			} else {
226
-				shutdownErr = ErrClosed
227
-			}
228
-
229
-			shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
230
-
231
-			c.err = shutdownErr
232
-			for _, waiter := range waiters {
233
-				waiter.errs <- shutdownErr
222
+		case <-receiversDone:
223
+			// all the receivers have exited
224
+			if recv.err != nil {
225
+				c.setError(recv.err)
234 226
 			}
227
+			// don't return out, let the close of the context trigger the abort of waiters
235 228
 			c.Close()
236
-			return
237
-		case <-c.closed:
238
-			if c.err == nil {
239
-				c.err = ErrClosed
240
-			}
241
-			// broadcast the shutdown error to the remaining waiters.
242
-			for _, waiter := range waiters {
243
-				waiter.errs <- c.err
244
-			}
229
+		case <-c.ctx.Done():
230
+			abortWaiters(c.error())
245 231
 			return
246 232
 		}
247 233
 	}
248 234
 }
249 235
 
250
-func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
236
+func (c *Client) error() error {
237
+	c.errOnce.Do(func() {
238
+		if c.err == nil {
239
+			c.err = ErrClosed
240
+		}
241
+	})
242
+	return c.err
243
+}
244
+
245
+func (c *Client) setError(err error) {
246
+	c.errOnce.Do(func() {
247
+		c.err = err
248
+	})
249
+}
250
+
251
+func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
251 252
 	p, err := c.codec.Marshal(msg)
252 253
 	if err != nil {
253 254
 		return err
254 255
 	}
255 256
 
256
-	return c.channel.send(ctx, streamID, mtype, p)
257
+	return c.channel.send(streamID, mtype, p)
257 258
 }
258 259
 
259 260
 func (c *Client) recv(resp *Response, msg *message) error {
... ...
@@ -274,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
274 274
 //
275 275
 // This purposely ignores errors with a wrapped cause.
276 276
 func filterCloseErr(err error) error {
277
-	if err == nil {
277
+	switch {
278
+	case err == nil:
278 279
 		return nil
279
-	}
280
-
281
-	if err == io.EOF {
280
+	case err == io.EOF:
282 281
 		return ErrClosed
283
-	}
284
-
285
-	if strings.Contains(err.Error(), "use of closed network connection") {
282
+	case errors.Cause(err) == io.EOF:
286 283
 		return ErrClosed
287
-	}
288
-
289
-	// if we have an epipe on a write, we cast to errclosed
290
-	if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
291
-		if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
292
-			return ErrClosed
284
+	case strings.Contains(err.Error(), "use of closed network connection"):
285
+		return ErrClosed
286
+	default:
287
+		// if we have an epipe on a write, we cast to errclosed
288
+		if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
289
+			if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
290
+				return ErrClosed
291
+			}
293 292
 		}
294 293
 	}
295 294
 
... ...
@@ -19,9 +19,11 @@ package ttrpc
19 19
 import "github.com/pkg/errors"
20 20
 
21 21
 type serverConfig struct {
22
-	handshaker Handshaker
22
+	handshaker  Handshaker
23
+	interceptor UnaryServerInterceptor
23 24
 }
24 25
 
26
+// ServerOpt for configuring a ttrpc server
25 27
 type ServerOpt func(*serverConfig) error
26 28
 
27 29
 // WithServerHandshaker can be passed to NewServer to ensure that the
... ...
@@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
37 37
 		return nil
38 38
 	}
39 39
 }
40
+
41
+// WithUnaryServerInterceptor sets the provided interceptor on the server
42
+func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
43
+	return func(c *serverConfig) error {
44
+		if c.interceptor != nil {
45
+			return errors.New("only one interceptor allowed per server")
46
+		}
47
+		c.interceptor = i
48
+		return nil
49
+	}
50
+}
40 51
new file mode 100644
... ...
@@ -0,0 +1,50 @@
0
+/*
1
+   Copyright The containerd Authors.
2
+
3
+   Licensed under the Apache License, Version 2.0 (the "License");
4
+   you may not use this file except in compliance with the License.
5
+   You may obtain a copy of the License at
6
+
7
+       http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+   Unless required by applicable law or agreed to in writing, software
10
+   distributed under the License is distributed on an "AS IS" BASIS,
11
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+   See the License for the specific language governing permissions and
13
+   limitations under the License.
14
+*/
15
+
16
+package ttrpc
17
+
18
+import "context"
19
+
20
+// UnaryServerInfo provides information about the server request
21
+type UnaryServerInfo struct {
22
+	FullMethod string
23
+}
24
+
25
+// UnaryClientInfo provides information about the client request
26
+type UnaryClientInfo struct {
27
+	FullMethod string
28
+}
29
+
30
+// Unmarshaler contains the server request data and allows it to be unmarshaled
31
+// into a concrete type
32
+type Unmarshaler func(interface{}) error
33
+
34
+// Invoker invokes the client's request and response from the ttrpc server
35
+type Invoker func(context.Context, *Request, *Response) error
36
+
37
+// UnaryServerInterceptor specifies the interceptor function for server request/response
38
+type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
39
+
40
+// UnaryClientInterceptor specifies the interceptor function for client request/response
41
+type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error
42
+
43
+func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
44
+	return method(ctx, unmarshal)
45
+}
46
+
47
+func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
48
+	return invoker(ctx, req, resp)
49
+}
0 50
new file mode 100644
... ...
@@ -0,0 +1,107 @@
0
+/*
1
+   Copyright The containerd Authors.
2
+
3
+   Licensed under the Apache License, Version 2.0 (the "License");
4
+   you may not use this file except in compliance with the License.
5
+   You may obtain a copy of the License at
6
+
7
+       http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+   Unless required by applicable law or agreed to in writing, software
10
+   distributed under the License is distributed on an "AS IS" BASIS,
11
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+   See the License for the specific language governing permissions and
13
+   limitations under the License.
14
+*/
15
+
16
+package ttrpc
17
+
18
+import (
19
+	"context"
20
+	"strings"
21
+)
22
+
23
+// MD is the user type for ttrpc metadata
24
+type MD map[string][]string
25
+
26
+// Get returns the metadata for a given key when they exist.
27
+// If there is no metadata, a nil slice and false are returned.
28
+func (m MD) Get(key string) ([]string, bool) {
29
+	key = strings.ToLower(key)
30
+	list, ok := m[key]
31
+	if !ok || len(list) == 0 {
32
+		return nil, false
33
+	}
34
+
35
+	return list, true
36
+}
37
+
38
+// Set sets the provided values for a given key.
39
+// The values will overwrite any existing values.
40
+// If no values provided, a key will be deleted.
41
+func (m MD) Set(key string, values ...string) {
42
+	key = strings.ToLower(key)
43
+	if len(values) == 0 {
44
+		delete(m, key)
45
+		return
46
+	}
47
+	m[key] = values
48
+}
49
+
50
+// Append appends additional values to the given key.
51
+func (m MD) Append(key string, values ...string) {
52
+	key = strings.ToLower(key)
53
+	if len(values) == 0 {
54
+		return
55
+	}
56
+	current, ok := m[key]
57
+	if ok {
58
+		m.Set(key, append(current, values...)...)
59
+	} else {
60
+		m.Set(key, values...)
61
+	}
62
+}
63
+
64
+func (m MD) setRequest(r *Request) {
65
+	for k, values := range m {
66
+		for _, v := range values {
67
+			r.Metadata = append(r.Metadata, &KeyValue{
68
+				Key:   k,
69
+				Value: v,
70
+			})
71
+		}
72
+	}
73
+}
74
+
75
+func (m MD) fromRequest(r *Request) {
76
+	for _, kv := range r.Metadata {
77
+		m[kv.Key] = append(m[kv.Key], kv.Value)
78
+	}
79
+}
80
+
81
+type metadataKey struct{}
82
+
83
+// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
84
+func GetMetadata(ctx context.Context) (MD, bool) {
85
+	metadata, ok := ctx.Value(metadataKey{}).(MD)
86
+	return metadata, ok
87
+}
88
+
89
+// GetMetadataValue gets a specific metadata value by name from context.Context
90
+func GetMetadataValue(ctx context.Context, name string) (string, bool) {
91
+	metadata, ok := GetMetadata(ctx)
92
+	if !ok {
93
+		return "", false
94
+	}
95
+
96
+	if list, ok := metadata.Get(name); ok {
97
+		return list[0], true
98
+	}
99
+
100
+	return "", false
101
+}
102
+
103
+// WithMetadata attaches metadata map to a context.Context
104
+func WithMetadata(ctx context.Context, md MD) context.Context {
105
+	return context.WithValue(ctx, metadataKey{}, md)
106
+}
... ...
@@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
53 53
 			return nil, err
54 54
 		}
55 55
 	}
56
+	if config.interceptor == nil {
57
+		config.interceptor = defaultServerInterceptor
58
+	}
56 59
 
57 60
 	return &Server{
58 61
 		config:      config,
59
-		services:    newServiceSet(),
62
+		services:    newServiceSet(config.interceptor),
60 63
 		done:        make(chan struct{}),
61 64
 		listeners:   make(map[net.Listener]struct{}),
62 65
 		connections: make(map[*serverConn]struct{}),
... ...
@@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
341 341
 			default: // proceed
342 342
 			}
343 343
 
344
-			mh, p, err := ch.recv(ctx)
344
+			mh, p, err := ch.recv()
345 345
 			if err != nil {
346 346
 				status, ok := status.FromError(err)
347 347
 				if !ok {
... ...
@@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
438 438
 				return
439 439
 			}
440 440
 
441
-			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
441
+			if err := ch.send(response.id, messageTypeResponse, p); err != nil {
442 442
 				logrus.WithError(err).Error("failed sending message on channel")
443 443
 				return
444 444
 			}
... ...
@@ -449,7 +452,12 @@ func (c *serverConn) run(sctx context.Context) {
449 449
 			// branch. Basically, it means that we are no longer receiving
450 450
 			// requests due to a terminal error.
451 451
 			recvErr = nil // connection is now "closing"
452
-			if err != nil && err != io.EOF {
452
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
453
+				// The client went away and we should stop processing
454
+				// requests, so that the client connection is closed
455
+				return
456
+			}
457
+			if err != nil {
453 458
 				logrus.WithError(err).Error("error receiving message")
454 459
 			}
455 460
 		case <-shutdown:
... ...
@@ -461,6 +469,12 @@ func (c *serverConn) run(sctx context.Context) {
461 461
 var noopFunc = func() {}
462 462
 
463 463
 func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
464
+	if len(req.Metadata) > 0 {
465
+		md := MD{}
466
+		md.fromRequest(req)
467
+		ctx = WithMetadata(ctx, md)
468
+	}
469
+
464 470
 	cancel = noopFunc
465 471
 	if req.TimeoutNano == 0 {
466 472
 		return ctx, cancel
... ...
@@ -37,12 +37,14 @@ type ServiceDesc struct {
37 37
 }
38 38
 
39 39
 type serviceSet struct {
40
-	services map[string]ServiceDesc
40
+	services    map[string]ServiceDesc
41
+	interceptor UnaryServerInterceptor
41 42
 }
42 43
 
43
-func newServiceSet() *serviceSet {
44
+func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
44 45
 	return &serviceSet{
45
-		services: make(map[string]ServiceDesc),
46
+		services:    make(map[string]ServiceDesc),
47
+		interceptor: interceptor,
46 48
 	}
47 49
 }
48 50
 
... ...
@@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin
84 84
 		return nil
85 85
 	}
86 86
 
87
-	resp, err := method(ctx, unmarshal)
87
+	info := &UnaryServerInfo{
88
+		FullMethod: fullPath(serviceName, methodName),
89
+	}
90
+
91
+	resp, err := s.interceptor(ctx, unmarshal, info, method)
88 92
 	if err != nil {
89 93
 		return nil, err
90 94
 	}
... ...
@@ -146,5 +152,5 @@ func convertCode(err error) codes.Code {
146 146
 }
147 147
 
148 148
 func fullPath(service, method string) string {
149
-	return "/" + path.Join("/", service, method)
149
+	return "/" + path.Join(service, method)
150 150
 }
... ...
@@ -23,10 +23,11 @@ import (
23 23
 )
24 24
 
25 25
 type Request struct {
26
-	Service     string `protobuf:"bytes,1,opt,name=service,proto3"`
27
-	Method      string `protobuf:"bytes,2,opt,name=method,proto3"`
28
-	Payload     []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
29
-	TimeoutNano int64  `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
26
+	Service     string      `protobuf:"bytes,1,opt,name=service,proto3"`
27
+	Method      string      `protobuf:"bytes,2,opt,name=method,proto3"`
28
+	Payload     []byte      `protobuf:"bytes,3,opt,name=payload,proto3"`
29
+	TimeoutNano int64       `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
30
+	Metadata    []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"`
30 31
 }
31 32
 
32 33
 func (r *Request) Reset()         { *r = Request{} }
... ...
@@ -41,3 +42,22 @@ type Response struct {
41 41
 func (r *Response) Reset()         { *r = Response{} }
42 42
 func (r *Response) String() string { return fmt.Sprintf("%+#v", r) }
43 43
 func (r *Response) ProtoMessage()  {}
44
+
45
+type StringList struct {
46
+	List []string `protobuf:"bytes,1,rep,name=list,proto3"`
47
+}
48
+
49
+func (r *StringList) Reset()         { *r = StringList{} }
50
+func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
51
+func (r *StringList) ProtoMessage()  {}
52
+
53
+func makeStringList(item ...string) StringList { return StringList{List: item} }
54
+
55
+type KeyValue struct {
56
+	Key   string `protobuf:"bytes,1,opt,name=key,proto3"`
57
+	Value string `protobuf:"bytes,2,opt,name=value,proto3"`
58
+}
59
+
60
+func (m *KeyValue) Reset()         { *m = KeyValue{} }
61
+func (*KeyValue) ProtoMessage()    {}
62
+func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }
... ...
@@ -1,42 +1,96 @@
1 1
 # gRPC-Go
2 2
 
3
-[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc) [![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go)
3
+[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go)
4
+[![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc)
5
+[![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go)
4 6
 
5
-The Go implementation of [gRPC](https://grpc.io/): A high performance, open source, general RPC framework that puts mobile and HTTP/2 first. For more information see the [gRPC Quick Start: Go](https://grpc.io/docs/quickstart/go.html) guide.
7
+The Go implementation of [gRPC](https://grpc.io/): A high performance, open
8
+source, general RPC framework that puts mobile and HTTP/2 first. For more
9
+information see the [gRPC Quick Start:
10
+Go](https://grpc.io/docs/quickstart/go.html) guide.
6 11
 
7 12
 Installation
8 13
 ------------
9 14
 
10
-To install this package, you need to install Go and setup your Go workspace on your computer. The simplest way to install the library is to run:
15
+To install this package, you need to install Go and setup your Go workspace on
16
+your computer. The simplest way to install the library is to run:
11 17
 
12 18
 ```
13 19
 $ go get -u google.golang.org/grpc
14 20
 ```
15 21
 
22
+With Go module support (Go 1.11+), simply `import "google.golang.org/grpc"` in
23
+your source code and `go [build|run|test]` will automatically download the
24
+necessary dependencies ([Go modules
25
+ref](https://github.com/golang/go/wiki/Modules)).
26
+
27
+If you are trying to access grpc-go from within China, please see the
28
+[FAQ](#FAQ) below.
29
+
16 30
 Prerequisites
17 31
 -------------
18
-
19 32
 gRPC-Go requires Go 1.9 or later.
20 33
 
21
-Constraints
22
-The grpc package should only depend on standard Go packages and a small number of exceptions. If your contribution introduces new dependencies which are NOT in the [list](https://godoc.org/google.golang.org/grpc?imports), you need a discussion with gRPC-Go authors and consultants.
23
-
24 34
 Documentation
25 35
 -------------
26
-See [API documentation](https://godoc.org/google.golang.org/grpc) for package and API descriptions and find examples in the [examples directory](examples/).
36
+- See [godoc](https://godoc.org/google.golang.org/grpc) for package and API
37
+  descriptions.
38
+- Documentation on specific topics can be found in the [Documentation
39
+  directory](Documentation/).
40
+- Examples can be found in the [examples directory](examples/).
27 41
 
28 42
 Performance
29 43
 -----------
30
-See the current benchmarks for some of the languages supported in [this dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696).
44
+Performance benchmark data for grpc-go and other languages is maintained in
45
+[this
46
+dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696).
31 47
 
32 48
 Status
33 49
 ------
34
-General Availability [Google Cloud Platform Launch Stages](https://cloud.google.com/terms/launch-stages).
50
+General Availability [Google Cloud Platform Launch
51
+Stages](https://cloud.google.com/terms/launch-stages).
35 52
 
36 53
 FAQ
37 54
 ---
38 55
 
56
+#### I/O Timeout Errors
57
+
58
+The `golang.org` domain may be blocked from some countries.  `go get` usually
59
+produces an error like the following when this happens:
60
+
61
+```
62
+$ go get -u google.golang.org/grpc
63
+package google.golang.org/grpc: unrecognized import path "google.golang.org/grpc" (https fetch: Get https://google.golang.org/grpc?go-get=1: dial tcp 216.239.37.1:443: i/o timeout)
64
+```
65
+
66
+To build Go code, there are several options:
67
+
68
+- Set up a VPN and access google.golang.org through that.
69
+
70
+- Without Go module support: `git clone` the repo manually:
71
+
72
+  ```
73
+  git clone https://github.com/grpc/grpc-go.git $GOPATH/src/google.golang.org/grpc
74
+  ```
75
+
76
+  You will need to do the same for all of grpc's dependencies in `golang.org`,
77
+  e.g. `golang.org/x/net`.
78
+
79
+- With Go module support: it is possible to use the `replace` feature of `go
80
+  mod` to create aliases for golang.org packages.  In your project's directory:
81
+
82
+  ```
83
+  go mod edit -replace=google.golang.org/grpc=github.com/grpc/grpc-go@latest
84
+  go mod tidy
85
+  go mod vendor
86
+  go build -mod=vendor
87
+  ```
88
+
89
+  Again, this will need to be done for all transitive dependencies hosted on
90
+  golang.org as well.  Please refer to [this
91
+  issue](https://github.com/golang/go/issues/28652) in the golang repo regarding
92
+  this concern.
93
+
39 94
 #### Compiling error, undefined: grpc.SupportPackageIsVersion
40 95
 
41 96
 Please update proto package, gRPC package and rebuild the proto files:
... ...
@@ -43,7 +43,7 @@ type Address struct {
43 43
 
44 44
 // BalancerConfig specifies the configurations for Balancer.
45 45
 //
46
-// Deprecated: please use package balancer.
46
+// Deprecated: please use package balancer.  May be removed in a future 1.x release.
47 47
 type BalancerConfig struct {
48 48
 	// DialCreds is the transport credential the Balancer implementation can
49 49
 	// use to dial to a remote load balancer server. The Balancer implementations
... ...
@@ -57,7 +57,7 @@ type BalancerConfig struct {
57 57
 
58 58
 // BalancerGetOptions configures a Get call.
59 59
 //
60
-// Deprecated: please use package balancer.
60
+// Deprecated: please use package balancer.  May be removed in a future 1.x release.
61 61
 type BalancerGetOptions struct {
62 62
 	// BlockingWait specifies whether Get should block when there is no
63 63
 	// connected address.
... ...
@@ -66,7 +66,7 @@ type BalancerGetOptions struct {
66 66
 
67 67
 // Balancer chooses network addresses for RPCs.
68 68
 //
69
-// Deprecated: please use package balancer.
69
+// Deprecated: please use package balancer.  May be removed in a future 1.x release.
70 70
 type Balancer interface {
71 71
 	// Start does the initialization work to bootstrap a Balancer. For example,
72 72
 	// this function may start the name resolution and watch the updates. It will
... ...
@@ -120,7 +120,7 @@ type Balancer interface {
120 120
 // RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
121 121
 // the name resolution updates and updates the addresses available correspondingly.
122 122
 //
123
-// Deprecated: please use package balancer/roundrobin.
123
+// Deprecated: please use package balancer/roundrobin. May be removed in a future 1.x release.
124 124
 func RoundRobin(r naming.Resolver) Balancer {
125 125
 	return &roundRobin{r: r}
126 126
 }
... ...
@@ -22,6 +22,7 @@ package balancer
22 22
 
23 23
 import (
24 24
 	"context"
25
+	"encoding/json"
25 26
 	"errors"
26 27
 	"net"
27 28
 	"strings"
... ...
@@ -31,6 +32,7 @@ import (
31 31
 	"google.golang.org/grpc/internal"
32 32
 	"google.golang.org/grpc/metadata"
33 33
 	"google.golang.org/grpc/resolver"
34
+	"google.golang.org/grpc/serviceconfig"
34 35
 )
35 36
 
36 37
 var (
... ...
@@ -39,7 +41,10 @@ var (
39 39
 )
40 40
 
41 41
 // Register registers the balancer builder to the balancer map. b.Name
42
-// (lowercased) will be used as the name registered with this builder.
42
+// (lowercased) will be used as the name registered with this builder.  If the
43
+// Builder implements ConfigParser, ParseConfig will be called when new service
44
+// configs are received by the resolver, and the result will be provided to the
45
+// Balancer in UpdateClientConnState.
43 46
 //
44 47
 // NOTE: this function must only be called during initialization time (i.e. in
45 48
 // an init() function), and is not thread-safe. If multiple Balancers are
... ...
@@ -138,6 +143,8 @@ type ClientConn interface {
138 138
 	ResolveNow(resolver.ResolveNowOption)
139 139
 
140 140
 	// Target returns the dial target for this ClientConn.
141
+	//
142
+	// Deprecated: Use the Target field in the BuildOptions instead.
141 143
 	Target() string
142 144
 }
143 145
 
... ...
@@ -155,6 +162,10 @@ type BuildOptions struct {
155 155
 	Dialer func(context.Context, string) (net.Conn, error)
156 156
 	// ChannelzParentID is the entity parent's channelz unique identification number.
157 157
 	ChannelzParentID int64
158
+	// Target contains the parsed address info of the dial target. It is the same resolver.Target as
159
+	// passed to the resolver.
160
+	// See the documentation for the resolver.Target type for details about what it contains.
161
+	Target resolver.Target
158 162
 }
159 163
 
160 164
 // Builder creates a balancer.
... ...
@@ -166,6 +177,14 @@ type Builder interface {
166 166
 	Name() string
167 167
 }
168 168
 
169
+// ConfigParser parses load balancer configs.
170
+type ConfigParser interface {
171
+	// ParseConfig parses the JSON load balancer config provided into an
172
+	// internal form or returns an error if the config is invalid.  For future
173
+	// compatibility reasons, unknown fields in the config should be ignored.
174
+	ParseConfig(LoadBalancingConfigJSON json.RawMessage) (serviceconfig.LoadBalancingConfig, error)
175
+}
176
+
169 177
 // PickOptions contains addition information for the Pick operation.
170 178
 type PickOptions struct {
171 179
 	// FullMethodName is the method name that NewClientStream() is called
... ...
@@ -264,7 +283,7 @@ type Balancer interface {
264 264
 	// non-nil error to gRPC.
265 265
 	//
266 266
 	// Deprecated: if V2Balancer is implemented by the Balancer,
267
-	// UpdateResolverState will be called instead.
267
+	// UpdateClientConnState will be called instead.
268 268
 	HandleResolvedAddrs([]resolver.Address, error)
269 269
 	// Close closes the balancer. The balancer is not required to call
270 270
 	// ClientConn.RemoveSubConn for its existing SubConns.
... ...
@@ -277,14 +296,23 @@ type SubConnState struct {
277 277
 	// TODO: add last connection error
278 278
 }
279 279
 
280
+// ClientConnState describes the state of a ClientConn relevant to the
281
+// balancer.
282
+type ClientConnState struct {
283
+	ResolverState resolver.State
284
+	// The parsed load balancing configuration returned by the builder's
285
+	// ParseConfig method, if implemented.
286
+	BalancerConfig serviceconfig.LoadBalancingConfig
287
+}
288
+
280 289
 // V2Balancer is defined for documentation purposes.  If a Balancer also
281
-// implements V2Balancer, its UpdateResolverState method will be called instead
282
-// of HandleResolvedAddrs and its UpdateSubConnState will be called instead of
283
-// HandleSubConnStateChange.
290
+// implements V2Balancer, its UpdateClientConnState method will be called
291
+// instead of HandleResolvedAddrs and its UpdateSubConnState will be called
292
+// instead of HandleSubConnStateChange.
284 293
 type V2Balancer interface {
285
-	// UpdateResolverState is called by gRPC when the state of the resolver
294
+	// UpdateClientConnState is called by gRPC when the state of the ClientConn
286 295
 	// changes.
287
-	UpdateResolverState(resolver.State)
296
+	UpdateClientConnState(ClientConnState)
288 297
 	// UpdateSubConnState is called by gRPC when the state of a SubConn
289 298
 	// changes.
290 299
 	UpdateSubConnState(SubConn, SubConnState)
... ...
@@ -70,13 +70,15 @@ func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error)
70 70
 	panic("not implemented")
71 71
 }
72 72
 
73
-func (b *baseBalancer) UpdateResolverState(s resolver.State) {
74
-	// TODO: handle s.Err (log if not nil) once implemented.
75
-	// TODO: handle s.ServiceConfig?
76
-	grpclog.Infoln("base.baseBalancer: got new resolver state: ", s)
73
+func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) {
74
+	// TODO: handle s.ResolverState.Err (log if not nil) once implemented.
75
+	// TODO: handle s.ResolverState.ServiceConfig?
76
+	if grpclog.V(2) {
77
+		grpclog.Infoln("base.baseBalancer: got new ClientConn state: ", s)
78
+	}
77 79
 	// addrsSet is the set converted from addrs, it's used for quick lookup of an address.
78 80
 	addrsSet := make(map[resolver.Address]struct{})
79
-	for _, a := range s.Addresses {
81
+	for _, a := range s.ResolverState.Addresses {
80 82
 		addrsSet[a] = struct{}{}
81 83
 		if _, ok := b.subConns[a]; !ok {
82 84
 			// a is a new address (not existing in b.subConns).
... ...
@@ -127,10 +129,14 @@ func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectiv
127 127
 
128 128
 func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
129 129
 	s := state.ConnectivityState
130
-	grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s)
130
+	if grpclog.V(2) {
131
+		grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s)
132
+	}
131 133
 	oldS, ok := b.scStates[sc]
132 134
 	if !ok {
133
-		grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
135
+		if grpclog.V(2) {
136
+			grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
137
+		}
134 138
 		return
135 139
 	}
136 140
 	b.scStates[sc] = s
... ...
@@ -88,7 +88,7 @@ type ccBalancerWrapper struct {
88 88
 	cc               *ClientConn
89 89
 	balancer         balancer.Balancer
90 90
 	stateChangeQueue *scStateUpdateBuffer
91
-	resolverUpdateCh chan *resolver.State
91
+	ccUpdateCh       chan *balancer.ClientConnState
92 92
 	done             chan struct{}
93 93
 
94 94
 	mu       sync.Mutex
... ...
@@ -99,7 +99,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
99 99
 	ccb := &ccBalancerWrapper{
100 100
 		cc:               cc,
101 101
 		stateChangeQueue: newSCStateUpdateBuffer(),
102
-		resolverUpdateCh: make(chan *resolver.State, 1),
102
+		ccUpdateCh:       make(chan *balancer.ClientConnState, 1),
103 103
 		done:             make(chan struct{}),
104 104
 		subConns:         make(map[*acBalancerWrapper]struct{}),
105 105
 	}
... ...
@@ -126,7 +126,7 @@ func (ccb *ccBalancerWrapper) watcher() {
126 126
 			} else {
127 127
 				ccb.balancer.HandleSubConnStateChange(t.sc, t.state)
128 128
 			}
129
-		case s := <-ccb.resolverUpdateCh:
129
+		case s := <-ccb.ccUpdateCh:
130 130
 			select {
131 131
 			case <-ccb.done:
132 132
 				ccb.balancer.Close()
... ...
@@ -134,9 +134,9 @@ func (ccb *ccBalancerWrapper) watcher() {
134 134
 			default:
135 135
 			}
136 136
 			if ub, ok := ccb.balancer.(balancer.V2Balancer); ok {
137
-				ub.UpdateResolverState(*s)
137
+				ub.UpdateClientConnState(*s)
138 138
 			} else {
139
-				ccb.balancer.HandleResolvedAddrs(s.Addresses, nil)
139
+				ccb.balancer.HandleResolvedAddrs(s.ResolverState.Addresses, nil)
140 140
 			}
141 141
 		case <-ccb.done:
142 142
 		}
... ...
@@ -151,9 +151,11 @@ func (ccb *ccBalancerWrapper) watcher() {
151 151
 			for acbw := range scs {
152 152
 				ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain)
153 153
 			}
154
+			ccb.UpdateBalancerState(connectivity.Connecting, nil)
154 155
 			return
155 156
 		default:
156 157
 		}
158
+		ccb.cc.firstResolveEvent.Fire()
157 159
 	}
158 160
 }
159 161
 
... ...
@@ -178,9 +180,10 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co
178 178
 	})
179 179
 }
180 180
 
181
-func (ccb *ccBalancerWrapper) updateResolverState(s resolver.State) {
181
+func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) {
182 182
 	if ccb.cc.curBalancerName != grpclbName {
183 183
 		// Filter any grpclb addresses since we don't have the grpclb balancer.
184
+		s := &ccs.ResolverState
184 185
 		for i := 0; i < len(s.Addresses); {
185 186
 			if s.Addresses[i].Type == resolver.GRPCLB {
186 187
 				copy(s.Addresses[i:], s.Addresses[i+1:])
... ...
@@ -191,10 +194,10 @@ func (ccb *ccBalancerWrapper) updateResolverState(s resolver.State) {
191 191
 		}
192 192
 	}
193 193
 	select {
194
-	case <-ccb.resolverUpdateCh:
194
+	case <-ccb.ccUpdateCh:
195 195
 	default:
196 196
 	}
197
-	ccb.resolverUpdateCh <- &s
197
+	ccb.ccUpdateCh <- ccs
198 198
 }
199 199
 
200 200
 func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
... ...
@@ -20,7 +20,6 @@ package grpc
20 20
 
21 21
 import (
22 22
 	"context"
23
-	"strings"
24 23
 	"sync"
25 24
 
26 25
 	"google.golang.org/grpc/balancer"
... ...
@@ -34,13 +33,7 @@ type balancerWrapperBuilder struct {
34 34
 }
35 35
 
36 36
 func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
37
-	targetAddr := cc.Target()
38
-	targetSplitted := strings.Split(targetAddr, ":///")
39
-	if len(targetSplitted) >= 2 {
40
-		targetAddr = targetSplitted[1]
41
-	}
42
-
43
-	bwb.b.Start(targetAddr, BalancerConfig{
37
+	bwb.b.Start(opts.Target.Endpoint, BalancerConfig{
44 38
 		DialCreds: opts.DialCreds,
45 39
 		Dialer:    opts.Dialer,
46 40
 	})
... ...
@@ -49,7 +42,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B
49 49
 		balancer:   bwb.b,
50 50
 		pickfirst:  pickfirst,
51 51
 		cc:         cc,
52
-		targetAddr: targetAddr,
52
+		targetAddr: opts.Target.Endpoint,
53 53
 		startCh:    make(chan struct{}),
54 54
 		conns:      make(map[resolver.Address]balancer.SubConn),
55 55
 		connSt:     make(map[balancer.SubConn]*scState),
... ...
@@ -120,7 +113,7 @@ func (bw *balancerWrapper) lbWatcher() {
120 120
 	}
121 121
 
122 122
 	for addrs := range notifyCh {
123
-		grpclog.Infof("balancerWrapper: got update addr from Notify: %v\n", addrs)
123
+		grpclog.Infof("balancerWrapper: got update addr from Notify: %v", addrs)
124 124
 		if bw.pickfirst {
125 125
 			var (
126 126
 				oldA  resolver.Address
... ...
@@ -38,13 +38,13 @@ import (
38 38
 	"google.golang.org/grpc/grpclog"
39 39
 	"google.golang.org/grpc/internal/backoff"
40 40
 	"google.golang.org/grpc/internal/channelz"
41
-	"google.golang.org/grpc/internal/envconfig"
42 41
 	"google.golang.org/grpc/internal/grpcsync"
43 42
 	"google.golang.org/grpc/internal/transport"
44 43
 	"google.golang.org/grpc/keepalive"
45 44
 	"google.golang.org/grpc/resolver"
46 45
 	_ "google.golang.org/grpc/resolver/dns"         // To register dns resolver.
47 46
 	_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
47
+	"google.golang.org/grpc/serviceconfig"
48 48
 	"google.golang.org/grpc/status"
49 49
 )
50 50
 
... ...
@@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
137 137
 		opt.apply(&cc.dopts)
138 138
 	}
139 139
 
140
+	chainUnaryClientInterceptors(cc)
141
+	chainStreamClientInterceptors(cc)
142
+
140 143
 	defer func() {
141 144
 		if err != nil {
142 145
 			cc.Close()
... ...
@@ -290,6 +293,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
290 290
 		CredsBundle:      cc.dopts.copts.CredsBundle,
291 291
 		Dialer:           cc.dopts.copts.Dialer,
292 292
 		ChannelzParentID: cc.channelzID,
293
+		Target:           cc.parsedTarget,
293 294
 	}
294 295
 
295 296
 	// Build the resolver.
... ...
@@ -327,6 +331,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
327 327
 	return cc, nil
328 328
 }
329 329
 
330
+// chainUnaryClientInterceptors chains all unary client interceptors into one.
331
+func chainUnaryClientInterceptors(cc *ClientConn) {
332
+	interceptors := cc.dopts.chainUnaryInts
333
+	// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
334
+	// be executed before any other chained interceptors.
335
+	if cc.dopts.unaryInt != nil {
336
+		interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
337
+	}
338
+	var chainedInt UnaryClientInterceptor
339
+	if len(interceptors) == 0 {
340
+		chainedInt = nil
341
+	} else if len(interceptors) == 1 {
342
+		chainedInt = interceptors[0]
343
+	} else {
344
+		chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
345
+			return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
346
+		}
347
+	}
348
+	cc.dopts.unaryInt = chainedInt
349
+}
350
+
351
+// getChainUnaryInvoker recursively generate the chained unary invoker.
352
+func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
353
+	if curr == len(interceptors)-1 {
354
+		return finalInvoker
355
+	}
356
+	return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
357
+		return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
358
+	}
359
+}
360
+
361
+// chainStreamClientInterceptors chains all stream client interceptors into one.
362
+func chainStreamClientInterceptors(cc *ClientConn) {
363
+	interceptors := cc.dopts.chainStreamInts
364
+	// Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will
365
+	// be executed before any other chained interceptors.
366
+	if cc.dopts.streamInt != nil {
367
+		interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...)
368
+	}
369
+	var chainedInt StreamClientInterceptor
370
+	if len(interceptors) == 0 {
371
+		chainedInt = nil
372
+	} else if len(interceptors) == 1 {
373
+		chainedInt = interceptors[0]
374
+	} else {
375
+		chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
376
+			return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...)
377
+		}
378
+	}
379
+	cc.dopts.streamInt = chainedInt
380
+}
381
+
382
+// getChainStreamer recursively generate the chained client stream constructor.
383
+func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer {
384
+	if curr == len(interceptors)-1 {
385
+		return finalStreamer
386
+	}
387
+	return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
388
+		return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...)
389
+	}
390
+}
391
+
330 392
 // connectivityStateManager keeps the connectivity.State of ClientConn.
331 393
 // This struct will eventually be exported so the balancers can access it.
332 394
 type connectivityStateManager struct {
... ...
@@ -466,24 +532,6 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
466 466
 	}
467 467
 }
468 468
 
469
-// gRPC should resort to default service config when:
470
-// * resolver service config is disabled
471
-// * or, resolver does not return a service config or returns an invalid one.
472
-func (cc *ClientConn) fallbackToDefaultServiceConfig(sc string) bool {
473
-	if cc.dopts.disableServiceConfig {
474
-		return true
475
-	}
476
-	// The logic below is temporary, will be removed once we change the resolver.State ServiceConfig field type.
477
-	// Right now, we assume that empty service config string means resolver does not return a config.
478
-	if sc == "" {
479
-		return true
480
-	}
481
-	// TODO: the logic below is temporary. Once we finish the logic to validate service config
482
-	// in resolver, we will replace the logic below.
483
-	_, err := parseServiceConfig(sc)
484
-	return err != nil
485
-}
486
-
487 469
 func (cc *ClientConn) updateResolverState(s resolver.State) error {
488 470
 	cc.mu.Lock()
489 471
 	defer cc.mu.Unlock()
... ...
@@ -494,54 +542,47 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error {
494 494
 		return nil
495 495
 	}
496 496
 
497
-	if cc.fallbackToDefaultServiceConfig(s.ServiceConfig) {
497
+	if cc.dopts.disableServiceConfig || s.ServiceConfig == nil {
498 498
 		if cc.dopts.defaultServiceConfig != nil && cc.sc == nil {
499 499
 			cc.applyServiceConfig(cc.dopts.defaultServiceConfig)
500 500
 		}
501
-	} else {
502
-		// TODO: the parsing logic below will be moved inside resolver.
503
-		sc, err := parseServiceConfig(s.ServiceConfig)
504
-		if err != nil {
505
-			return err
506
-		}
507
-		if cc.sc == nil || cc.sc.rawJSONString != s.ServiceConfig {
508
-			cc.applyServiceConfig(sc)
509
-		}
510
-	}
511
-
512
-	// update the service config that will be sent to balancer.
513
-	if cc.sc != nil {
514
-		s.ServiceConfig = cc.sc.rawJSONString
501
+	} else if sc, ok := s.ServiceConfig.(*ServiceConfig); ok {
502
+		cc.applyServiceConfig(sc)
515 503
 	}
516 504
 
505
+	var balCfg serviceconfig.LoadBalancingConfig
517 506
 	if cc.dopts.balancerBuilder == nil {
518 507
 		// Only look at balancer types and switch balancer if balancer dial
519 508
 		// option is not set.
520
-		var isGRPCLB bool
521
-		for _, a := range s.Addresses {
522
-			if a.Type == resolver.GRPCLB {
523
-				isGRPCLB = true
524
-				break
525
-			}
526
-		}
527 509
 		var newBalancerName string
528
-		// TODO: use new loadBalancerConfig field with appropriate priority.
529
-		if isGRPCLB {
530
-			newBalancerName = grpclbName
531
-		} else if cc.sc != nil && cc.sc.LB != nil {
532
-			newBalancerName = *cc.sc.LB
510
+		if cc.sc != nil && cc.sc.lbConfig != nil {
511
+			newBalancerName = cc.sc.lbConfig.name
512
+			balCfg = cc.sc.lbConfig.cfg
533 513
 		} else {
534
-			newBalancerName = PickFirstBalancerName
514
+			var isGRPCLB bool
515
+			for _, a := range s.Addresses {
516
+				if a.Type == resolver.GRPCLB {
517
+					isGRPCLB = true
518
+					break
519
+				}
520
+			}
521
+			if isGRPCLB {
522
+				newBalancerName = grpclbName
523
+			} else if cc.sc != nil && cc.sc.LB != nil {
524
+				newBalancerName = *cc.sc.LB
525
+			} else {
526
+				newBalancerName = PickFirstBalancerName
527
+			}
535 528
 		}
536 529
 		cc.switchBalancer(newBalancerName)
537 530
 	} else if cc.balancerWrapper == nil {
538 531
 		// Balancer dial option was set, and this is the first time handling
539 532
 		// resolved addresses. Build a balancer with dopts.balancerBuilder.
533
+		cc.curBalancerName = cc.dopts.balancerBuilder.Name()
540 534
 		cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts)
541 535
 	}
542 536
 
543
-	cc.balancerWrapper.updateResolverState(s)
544
-	cc.firstResolveEvent.Fire()
537
+	cc.balancerWrapper.updateClientConnState(&balancer.ClientConnState{ResolverState: s, BalancerConfig: balCfg})
545 538
 	return nil
546 539
 }
547 540
 
... ...
@@ -554,7 +595,7 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error {
554 554
 //
555 555
 // Caller must hold cc.mu.
556 556
 func (cc *ClientConn) switchBalancer(name string) {
557
-	if strings.ToLower(cc.curBalancerName) == strings.ToLower(name) {
557
+	if strings.EqualFold(cc.curBalancerName, name) {
558 558
 		return
559 559
 	}
560 560
 
... ...
@@ -693,6 +734,8 @@ func (ac *addrConn) connect() error {
693 693
 		ac.mu.Unlock()
694 694
 		return nil
695 695
 	}
696
+	// Update connectivity state within the lock to prevent subsequent or
697
+	// concurrent calls from resetting the transport more than once.
696 698
 	ac.updateConnectivityState(connectivity.Connecting)
697 699
 	ac.mu.Unlock()
698 700
 
... ...
@@ -703,7 +746,16 @@ func (ac *addrConn) connect() error {
703 703
 
704 704
 // tryUpdateAddrs tries to update ac.addrs with the new addresses list.
705 705
 //
706
-// It checks whether current connected address of ac is in the new addrs list.
706
+// If ac is Connecting, it returns false. The caller should tear down the ac and
707
+// create a new one. Note that the backoff will be reset when this happens.
708
+//
709
+// If ac is TransientFailure, it updates ac.addrs and returns true. The updated
710
+// addresses will be picked up by retry in the next iteration after backoff.
711
+//
712
+// If ac is Shutdown or Idle, it updates ac.addrs and returns true.
713
+//
714
+// If ac is Ready, it checks whether current connected address of ac is in the
715
+// new addrs list.
707 716
 //  - If true, it updates ac.addrs and returns true. The ac will keep using
708 717
 //    the existing connection.
709 718
 //  - If false, it does nothing and returns false.
... ...
@@ -711,17 +763,18 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
711 711
 	ac.mu.Lock()
712 712
 	defer ac.mu.Unlock()
713 713
 	grpclog.Infof("addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs)
714
-	if ac.state == connectivity.Shutdown {
714
+	if ac.state == connectivity.Shutdown ||
715
+		ac.state == connectivity.TransientFailure ||
716
+		ac.state == connectivity.Idle {
715 717
 		ac.addrs = addrs
716 718
 		return true
717 719
 	}
718 720
 
719
-	// Unless we're busy reconnecting already, let's reconnect from the top of
720
-	// the list.
721
-	if ac.state != connectivity.Ready {
721
+	if ac.state == connectivity.Connecting {
722 722
 		return false
723 723
 	}
724 724
 
725
+	// ac.state is Ready, try to find the connected address.
725 726
 	var curAddrFound bool
726 727
 	for _, a := range addrs {
727 728
 		if reflect.DeepEqual(ac.curAddr, a) {
... ...
@@ -970,6 +1023,9 @@ func (ac *addrConn) resetTransport() {
970 970
 		// The spec doesn't mention what should be done for multiple addresses.
971 971
 		// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm
972 972
 		connectDeadline := time.Now().Add(dialDuration)
973
+
974
+		ac.updateConnectivityState(connectivity.Connecting)
975
+		ac.transport = nil
973 976
 		ac.mu.Unlock()
974 977
 
975 978
 		newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline)
... ...
@@ -1004,55 +1060,32 @@ func (ac *addrConn) resetTransport() {
1004 1004
 
1005 1005
 		ac.mu.Lock()
1006 1006
 		if ac.state == connectivity.Shutdown {
1007
-			newTr.Close()
1008 1007
 			ac.mu.Unlock()
1008
+			newTr.Close()
1009 1009
 			return
1010 1010
 		}
1011 1011
 		ac.curAddr = addr
1012 1012
 		ac.transport = newTr
1013 1013
 		ac.backoffIdx = 0
1014 1014
 
1015
-		healthCheckConfig := ac.cc.healthCheckConfig()
1016
-		// LB channel health checking is only enabled when all the four requirements below are met:
1017
-		// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption,
1018
-		// 2. the internal.HealthCheckFunc is set by importing the grpc/healthcheck package,
1019
-		// 3. a service config with non-empty healthCheckConfig field is provided,
1020
-		// 4. the current load balancer allows it.
1021 1015
 		hctx, hcancel := context.WithCancel(ac.ctx)
1022
-		healthcheckManagingState := false
1023
-		if !ac.cc.dopts.disableHealthCheck && healthCheckConfig != nil && ac.scopts.HealthCheckEnabled {
1024
-			if ac.cc.dopts.healthCheckFunc == nil {
1025
-				// TODO: add a link to the health check doc in the error message.
1026
-				grpclog.Error("the client side LB channel health check function has not been set.")
1027
-			} else {
1028
-				// TODO(deklerk) refactor to just return transport
1029
-				go ac.startHealthCheck(hctx, newTr, addr, healthCheckConfig.ServiceName)
1030
-				healthcheckManagingState = true
1031
-			}
1032
-		}
1033
-		if !healthcheckManagingState {
1034
-			ac.updateConnectivityState(connectivity.Ready)
1035
-		}
1016
+		ac.startHealthCheck(hctx)
1036 1017
 		ac.mu.Unlock()
1037 1018
 
1038 1019
 		// Block until the created transport is down. And when this happens,
1039 1020
 		// we restart from the top of the addr list.
1040 1021
 		<-reconnect.Done()
1041 1022
 		hcancel()
1042
-
1043
-		// Need to reconnect after a READY, the addrConn enters
1044
-		// TRANSIENT_FAILURE.
1023
+		// restart connecting - the top of the loop will set state to
1024
+		// CONNECTING.  This is against the current connectivity semantics doc,
1025
+		// however it allows for graceful behavior for RPCs not yet dispatched
1026
+		// - unfortunate timing would otherwise lead to the RPC failing even
1027
+		// though the TRANSIENT_FAILURE state (called for by the doc) would be
1028
+		// instantaneous.
1045 1029
 		//
1046
-		// This will set addrConn to TRANSIENT_FAILURE for a very short period
1047
-		// of time, and turns CONNECTING. It seems reasonable to skip this, but
1048
-		// READY-CONNECTING is not a valid transition.
1049
-		ac.mu.Lock()
1050
-		if ac.state == connectivity.Shutdown {
1051
-			ac.mu.Unlock()
1052
-			return
1053
-		}
1054
-		ac.updateConnectivityState(connectivity.TransientFailure)
1055
-		ac.mu.Unlock()
1030
+		// Ideally we should transition to Idle here and block until there is
1031
+		// RPC activity that leads to the balancer requesting a reconnect of
1032
+		// the associated SubConn.
1056 1033
 	}
1057 1034
 }
1058 1035
 
... ...
@@ -1066,8 +1099,6 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
1066 1066
 			ac.mu.Unlock()
1067 1067
 			return nil, resolver.Address{}, nil, errConnClosing
1068 1068
 		}
1069
-		ac.updateConnectivityState(connectivity.Connecting)
1070
-		ac.transport = nil
1071 1069
 
1072 1070
 		ac.cc.mu.RLock()
1073 1071
 		ac.dopts.copts.KeepaliveParams = ac.cc.mkp
... ...
@@ -1111,14 +1142,35 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
1111 1111
 		Authority: ac.cc.authority,
1112 1112
 	}
1113 1113
 
1114
+	once := sync.Once{}
1114 1115
 	onGoAway := func(r transport.GoAwayReason) {
1115 1116
 		ac.mu.Lock()
1116 1117
 		ac.adjustParams(r)
1118
+		once.Do(func() {
1119
+			if ac.state == connectivity.Ready {
1120
+				// Prevent this SubConn from being used for new RPCs by setting its
1121
+				// state to Connecting.
1122
+				//
1123
+				// TODO: this should be Idle when grpc-go properly supports it.
1124
+				ac.updateConnectivityState(connectivity.Connecting)
1125
+			}
1126
+		})
1117 1127
 		ac.mu.Unlock()
1118 1128
 		reconnect.Fire()
1119 1129
 	}
1120 1130
 
1121 1131
 	onClose := func() {
1132
+		ac.mu.Lock()
1133
+		once.Do(func() {
1134
+			if ac.state == connectivity.Ready {
1135
+				// Prevent this SubConn from being used for new RPCs by setting its
1136
+				// state to Connecting.
1137
+				//
1138
+				// TODO: this should be Idle when grpc-go properly supports it.
1139
+				ac.updateConnectivityState(connectivity.Connecting)
1140
+			}
1141
+		})
1142
+		ac.mu.Unlock()
1122 1143
 		close(onCloseCalled)
1123 1144
 		reconnect.Fire()
1124 1145
 	}
... ...
@@ -1140,60 +1192,99 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
1140 1140
 		return nil, nil, err
1141 1141
 	}
1142 1142
 
1143
-	if ac.dopts.reqHandshake == envconfig.RequireHandshakeOn {
1144
-		select {
1145
-		case <-time.After(connectDeadline.Sub(time.Now())):
1146
-			// We didn't get the preface in time.
1147
-			newTr.Close()
1148
-			grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr)
1149
-			return nil, nil, errors.New("timed out waiting for server handshake")
1150
-		case <-prefaceReceived:
1151
-			// We got the preface - huzzah! things are good.
1152
-		case <-onCloseCalled:
1153
-			// The transport has already closed - noop.
1154
-			return nil, nil, errors.New("connection closed")
1155
-			// TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix.
1156
-		}
1143
+	select {
1144
+	case <-time.After(connectDeadline.Sub(time.Now())):
1145
+		// We didn't get the preface in time.
1146
+		newTr.Close()
1147
+		grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr)
1148
+		return nil, nil, errors.New("timed out waiting for server handshake")
1149
+	case <-prefaceReceived:
1150
+		// We got the preface - huzzah! things are good.
1151
+	case <-onCloseCalled:
1152
+		// The transport has already closed - noop.
1153
+		return nil, nil, errors.New("connection closed")
1154
+		// TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix.
1157 1155
 	}
1158 1156
 	return newTr, reconnect, nil
1159 1157
 }
1160 1158
 
1161
-func (ac *addrConn) startHealthCheck(ctx context.Context, newTr transport.ClientTransport, addr resolver.Address, serviceName string) {
1162
-	// Set up the health check helper functions
1163
-	newStream := func() (interface{}, error) {
1164
-		return ac.newClientStream(ctx, &StreamDesc{ServerStreams: true}, "/grpc.health.v1.Health/Watch", newTr)
1159
+// startHealthCheck starts the health checking stream (RPC) to watch the health
1160
+// stats of this connection if health checking is requested and configured.
1161
+//
1162
+// LB channel health checking is enabled when all requirements below are met:
1163
+// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption
1164
+// 2. internal.HealthCheckFunc is set by importing the grpc/healthcheck package
1165
+// 3. a service config with non-empty healthCheckConfig field is provided
1166
+// 4. the load balancer requests it
1167
+//
1168
+// It sets addrConn to READY if the health checking stream is not started.
1169
+//
1170
+// Caller must hold ac.mu.
1171
+func (ac *addrConn) startHealthCheck(ctx context.Context) {
1172
+	var healthcheckManagingState bool
1173
+	defer func() {
1174
+		if !healthcheckManagingState {
1175
+			ac.updateConnectivityState(connectivity.Ready)
1176
+		}
1177
+	}()
1178
+
1179
+	if ac.cc.dopts.disableHealthCheck {
1180
+		return
1181
+	}
1182
+	healthCheckConfig := ac.cc.healthCheckConfig()
1183
+	if healthCheckConfig == nil {
1184
+		return
1185
+	}
1186
+	if !ac.scopts.HealthCheckEnabled {
1187
+		return
1188
+	}
1189
+	healthCheckFunc := ac.cc.dopts.healthCheckFunc
1190
+	if healthCheckFunc == nil {
1191
+		// The health package is not imported to set health check function.
1192
+		//
1193
+		// TODO: add a link to the health check doc in the error message.
1194
+		grpclog.Error("Health check is requested but health check function is not set.")
1195
+		return
1196
+	}
1197
+
1198
+	healthcheckManagingState = true
1199
+
1200
+	// Set up the health check helper functions.
1201
+	currentTr := ac.transport
1202
+	newStream := func(method string) (interface{}, error) {
1203
+		ac.mu.Lock()
1204
+		if ac.transport != currentTr {
1205
+			ac.mu.Unlock()
1206
+			return nil, status.Error(codes.Canceled, "the provided transport is no longer valid to use")
1207
+		}
1208
+		ac.mu.Unlock()
1209
+		return newNonRetryClientStream(ctx, &StreamDesc{ServerStreams: true}, method, currentTr, ac)
1165 1210
 	}
1166
-	firstReady := true
1167
-	reportHealth := func(ok bool) {
1211
+	setConnectivityState := func(s connectivity.State) {
1168 1212
 		ac.mu.Lock()
1169 1213
 		defer ac.mu.Unlock()
1170
-		if ac.transport != newTr {
1214
+		if ac.transport != currentTr {
1171 1215
 			return
1172 1216
 		}
1173
-		if ok {
1174
-			if firstReady {
1175
-				firstReady = false
1176
-				ac.curAddr = addr
1177
-			}
1178
-			ac.updateConnectivityState(connectivity.Ready)
1179
-		} else {
1180
-			ac.updateConnectivityState(connectivity.TransientFailure)
1181
-		}
1217
+		ac.updateConnectivityState(s)
1182 1218
 	}
1183
-	err := ac.cc.dopts.healthCheckFunc(ctx, newStream, reportHealth, serviceName)
1184
-	if err != nil {
1185
-		if status.Code(err) == codes.Unimplemented {
1186
-			if channelz.IsOn() {
1187
-				channelz.AddTraceEvent(ac.channelzID, &channelz.TraceEventDesc{
1188
-					Desc:     "Subchannel health check is unimplemented at server side, thus health check is disabled",
1189
-					Severity: channelz.CtError,
1190
-				})
1219
+	// Start the health checking stream.
1220
+	go func() {
1221
+		err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
1222
+		if err != nil {
1223
+			if status.Code(err) == codes.Unimplemented {
1224
+				if channelz.IsOn() {
1225
+					channelz.AddTraceEvent(ac.channelzID, &channelz.TraceEventDesc{
1226
+						Desc:     "Subchannel health check is unimplemented at server side, thus health check is disabled",
1227
+						Severity: channelz.CtError,
1228
+					})
1229
+				}
1230
+				grpclog.Error("Subchannel health check is unimplemented at server side, thus health check is disabled")
1231
+			} else {
1232
+				grpclog.Errorf("HealthCheckFunc exits with unexpected error %v", err)
1191 1233
 			}
1192
-			grpclog.Error("Subchannel health check is unimplemented at server side, thus health check is disabled")
1193
-		} else {
1194
-			grpclog.Errorf("HealthCheckFunc exits with unexpected error %v", err)
1195 1234
 		}
1196
-	}
1235
+	}()
1197 1236
 }
1198 1237
 
1199 1238
 func (ac *addrConn) resetConnectBackoff() {
... ...
@@ -132,7 +132,8 @@ const (
132 132
 
133 133
 	// Unavailable indicates the service is currently unavailable.
134 134
 	// This is a most likely a transient condition and may be corrected
135
-	// by retrying with a backoff.
135
+	// by retrying with a backoff. Note that it is not always safe to retry
136
+	// non-idempotent operations.
136 137
 	//
137 138
 	// See litmus test above for deciding between FailedPrecondition,
138 139
 	// Aborted, and Unavailable.
... ...
@@ -278,24 +278,22 @@ type ChannelzSecurityValue interface {
278 278
 // TLSChannelzSecurityValue defines the struct that TLS protocol should return
279 279
 // from GetSecurityValue(), containing security info like cipher and certificate used.
280 280
 type TLSChannelzSecurityValue struct {
281
+	ChannelzSecurityValue
281 282
 	StandardName      string
282 283
 	LocalCertificate  []byte
283 284
 	RemoteCertificate []byte
284 285
 }
285 286
 
286
-func (*TLSChannelzSecurityValue) isChannelzSecurityValue() {}
287
-
288 287
 // OtherChannelzSecurityValue defines the struct that non-TLS protocol should return
289 288
 // from GetSecurityValue(), which contains protocol specific security info. Note
290 289
 // the Value field will be sent to users of channelz requesting channel info, and
291 290
 // thus sensitive info should better be avoided.
292 291
 type OtherChannelzSecurityValue struct {
292
+	ChannelzSecurityValue
293 293
 	Name  string
294 294
 	Value proto.Message
295 295
 }
296 296
 
297
-func (*OtherChannelzSecurityValue) isChannelzSecurityValue() {}
298
-
299 297
 var cipherSuiteLookup = map[uint16]string{
300 298
 	tls.TLS_RSA_WITH_RC4_128_SHA:                "TLS_RSA_WITH_RC4_128_SHA",
301 299
 	tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA:           "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
... ...
@@ -39,8 +39,12 @@ import (
39 39
 // dialOptions configure a Dial call. dialOptions are set by the DialOption
40 40
 // values passed to Dial.
41 41
 type dialOptions struct {
42
-	unaryInt    UnaryClientInterceptor
43
-	streamInt   StreamClientInterceptor
42
+	unaryInt  UnaryClientInterceptor
43
+	streamInt StreamClientInterceptor
44
+
45
+	chainUnaryInts  []UnaryClientInterceptor
46
+	chainStreamInts []StreamClientInterceptor
47
+
44 48
 	cp          Compressor
45 49
 	dc          Decompressor
46 50
 	bs          backoff.Strategy
... ...
@@ -56,7 +60,6 @@ type dialOptions struct {
56 56
 	balancerBuilder balancer.Builder
57 57
 	// This is to support grpclb.
58 58
 	resolverBuilder             resolver.Builder
59
-	reqHandshake                envconfig.RequireHandshakeSetting
60 59
 	channelzParentID            int64
61 60
 	disableServiceConfig        bool
62 61
 	disableRetry                bool
... ...
@@ -96,17 +99,6 @@ func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
96 96
 	}
97 97
 }
98 98
 
99
-// WithWaitForHandshake blocks until the initial settings frame is received from
100
-// the server before assigning RPCs to the connection.
101
-//
102
-// Deprecated: this is the default behavior, and this option will be removed
103
-// after the 1.18 release.
104
-func WithWaitForHandshake() DialOption {
105
-	return newFuncDialOption(func(o *dialOptions) {
106
-		o.reqHandshake = envconfig.RequireHandshakeOn
107
-	})
108
-}
109
-
110 99
 // WithWriteBufferSize determines how much data can be batched before doing a
111 100
 // write on the wire. The corresponding memory allocation for this buffer will
112 101
 // be twice the size to keep syscalls low. The default value for this buffer is
... ...
@@ -152,7 +144,8 @@ func WithInitialConnWindowSize(s int32) DialOption {
152 152
 // WithMaxMsgSize returns a DialOption which sets the maximum message size the
153 153
 // client can receive.
154 154
 //
155
-// Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.
155
+// Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.  Will
156
+// be supported throughout 1.x.
156 157
 func WithMaxMsgSize(s int) DialOption {
157 158
 	return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
158 159
 }
... ...
@@ -168,7 +161,8 @@ func WithDefaultCallOptions(cos ...CallOption) DialOption {
168 168
 // WithCodec returns a DialOption which sets a codec for message marshaling and
169 169
 // unmarshaling.
170 170
 //
171
-// Deprecated: use WithDefaultCallOptions(ForceCodec(_)) instead.
171
+// Deprecated: use WithDefaultCallOptions(ForceCodec(_)) instead.  Will be
172
+// supported throughout 1.x.
172 173
 func WithCodec(c Codec) DialOption {
173 174
 	return WithDefaultCallOptions(CallCustomCodec(c))
174 175
 }
... ...
@@ -177,7 +171,7 @@ func WithCodec(c Codec) DialOption {
177 177
 // message compression. It has lower priority than the compressor set by the
178 178
 // UseCompressor CallOption.
179 179
 //
180
-// Deprecated: use UseCompressor instead.
180
+// Deprecated: use UseCompressor instead.  Will be supported throughout 1.x.
181 181
 func WithCompressor(cp Compressor) DialOption {
182 182
 	return newFuncDialOption(func(o *dialOptions) {
183 183
 		o.cp = cp
... ...
@@ -192,7 +186,8 @@ func WithCompressor(cp Compressor) DialOption {
192 192
 // message.  If no compressor is registered for the encoding, an Unimplemented
193 193
 // status error will be returned.
194 194
 //
195
-// Deprecated: use encoding.RegisterCompressor instead.
195
+// Deprecated: use encoding.RegisterCompressor instead.  Will be supported
196
+// throughout 1.x.
196 197
 func WithDecompressor(dc Decompressor) DialOption {
197 198
 	return newFuncDialOption(func(o *dialOptions) {
198 199
 		o.dc = dc
... ...
@@ -203,7 +198,7 @@ func WithDecompressor(dc Decompressor) DialOption {
203 203
 // Name resolver will be ignored if this DialOption is specified.
204 204
 //
205 205
 // Deprecated: use the new balancer APIs in balancer package and
206
-// WithBalancerName.
206
+// WithBalancerName.  Will be removed in a future 1.x release.
207 207
 func WithBalancer(b Balancer) DialOption {
208 208
 	return newFuncDialOption(func(o *dialOptions) {
209 209
 		o.balancerBuilder = &balancerWrapperBuilder{
... ...
@@ -219,7 +214,8 @@ func WithBalancer(b Balancer) DialOption {
219 219
 // The balancer cannot be overridden by balancer option specified by service
220 220
 // config.
221 221
 //
222
-// This is an EXPERIMENTAL API.
222
+// Deprecated: use WithDefaultServiceConfig and WithDisableServiceConfig
223
+// instead.  Will be removed in a future 1.x release.
223 224
 func WithBalancerName(balancerName string) DialOption {
224 225
 	builder := balancer.Get(balancerName)
225 226
 	if builder == nil {
... ...
@@ -240,9 +236,10 @@ func withResolverBuilder(b resolver.Builder) DialOption {
240 240
 // WithServiceConfig returns a DialOption which has a channel to read the
241 241
 // service configuration.
242 242
 //
243
-// Deprecated: service config should be received through name resolver, as
244
-// specified here.
245
-// https://github.com/grpc/grpc/blob/master/doc/service_config.md
243
+// Deprecated: service config should be received through name resolver or via
244
+// WithDefaultServiceConfig, as specified at
245
+// https://github.com/grpc/grpc/blob/master/doc/service_config.md.  Will be
246
+// removed in a future 1.x release.
246 247
 func WithServiceConfig(c <-chan ServiceConfig) DialOption {
247 248
 	return newFuncDialOption(func(o *dialOptions) {
248 249
 		o.scChan = c
... ...
@@ -325,7 +322,8 @@ func WithCredentialsBundle(b credentials.Bundle) DialOption {
325 325
 // WithTimeout returns a DialOption that configures a timeout for dialing a
326 326
 // ClientConn initially. This is valid if and only if WithBlock() is present.
327 327
 //
328
-// Deprecated: use DialContext and context.WithTimeout instead.
328
+// Deprecated: use DialContext and context.WithTimeout instead.  Will be
329
+// supported throughout 1.x.
329 330
 func WithTimeout(d time.Duration) DialOption {
330 331
 	return newFuncDialOption(func(o *dialOptions) {
331 332
 		o.timeout = d
... ...
@@ -352,7 +350,8 @@ func init() {
352 352
 // is returned by f, gRPC checks the error's Temporary() method to decide if it
353 353
 // should try to reconnect to the network address.
354 354
 //
355
-// Deprecated: use WithContextDialer instead
355
+// Deprecated: use WithContextDialer instead.  Will be supported throughout
356
+// 1.x.
356 357
 func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
357 358
 	return WithContextDialer(
358 359
 		func(ctx context.Context, addr string) (net.Conn, error) {
... ...
@@ -414,6 +413,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
414 414
 	})
415 415
 }
416 416
 
417
+// WithChainUnaryInterceptor returns a DialOption that specifies the chained
418
+// interceptor for unary RPCs. The first interceptor will be the outer most,
419
+// while the last interceptor will be the inner most wrapper around the real call.
420
+// All interceptors added by this method will be chained, and the interceptor
421
+// defined by WithUnaryInterceptor will always be prepended to the chain.
422
+func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption {
423
+	return newFuncDialOption(func(o *dialOptions) {
424
+		o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
425
+	})
426
+}
427
+
417 428
 // WithStreamInterceptor returns a DialOption that specifies the interceptor for
418 429
 // streaming RPCs.
419 430
 func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
... ...
@@ -422,6 +432,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
422 422
 	})
423 423
 }
424 424
 
425
+// WithChainStreamInterceptor returns a DialOption that specifies the chained
426
+// interceptor for unary RPCs. The first interceptor will be the outer most,
427
+// while the last interceptor will be the inner most wrapper around the real call.
428
+// All interceptors added by this method will be chained, and the interceptor
429
+// defined by WithStreamInterceptor will always be prepended to the chain.
430
+func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption {
431
+	return newFuncDialOption(func(o *dialOptions) {
432
+		o.chainStreamInts = append(o.chainStreamInts, interceptors...)
433
+	})
434
+}
435
+
425 436
 // WithAuthority returns a DialOption that specifies the value to be used as the
426 437
 // :authority pseudo-header. This value only works with WithInsecure and has no
427 438
 // effect if TransportCredentials are present.
... ...
@@ -440,12 +461,12 @@ func WithChannelzParentID(id int64) DialOption {
440 440
 	})
441 441
 }
442 442
 
443
-// WithDisableServiceConfig returns a DialOption that causes grpc to ignore any
443
+// WithDisableServiceConfig returns a DialOption that causes gRPC to ignore any
444 444
 // service config provided by the resolver and provides a hint to the resolver
445 445
 // to not fetch service configs.
446 446
 //
447
-// Note that, this dial option only disables service config from resolver. If
448
-// default service config is provided, grpc will use the default service config.
447
+// Note that this dial option only disables service config from resolver. If
448
+// default service config is provided, gRPC will use the default service config.
449 449
 func WithDisableServiceConfig() DialOption {
450 450
 	return newFuncDialOption(func(o *dialOptions) {
451 451
 		o.disableServiceConfig = true
... ...
@@ -454,8 +475,10 @@ func WithDisableServiceConfig() DialOption {
454 454
 
455 455
 // WithDefaultServiceConfig returns a DialOption that configures the default
456 456
 // service config, which will be used in cases where:
457
-// 1. WithDisableServiceConfig is called.
458
-// 2. Resolver does not return service config or if the resolver gets and invalid config.
457
+//
458
+// 1. WithDisableServiceConfig is also used.
459
+// 2. Resolver does not return a service config or if the resolver returns an
460
+//    invalid service config.
459 461
 //
460 462
 // This API is EXPERIMENTAL.
461 463
 func WithDefaultServiceConfig(s string) DialOption {
... ...
@@ -511,7 +534,6 @@ func withHealthCheckFunc(f internal.HealthChecker) DialOption {
511 511
 func defaultDialOptions() dialOptions {
512 512
 	return dialOptions{
513 513
 		disableRetry:    !envconfig.Retry,
514
-		reqHandshake:    envconfig.RequireHandshake,
515 514
 		healthCheckFunc: internal.HealthCheckFunc,
516 515
 		copts: transport.ConnectOptions{
517 516
 			WriteBufferSize: defaultWriteBufSize,
... ...
@@ -7,13 +7,13 @@ require (
7 7
 	github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
8 8
 	github.com/golang/mock v1.1.1
9 9
 	github.com/golang/protobuf v1.2.0
10
+	github.com/google/go-cmp v0.2.0
10 11
 	golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3
11 12
 	golang.org/x/net v0.0.0-20190311183353-d8887717615a
12 13
 	golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be
13
-	golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
14 14
 	golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
15
-	golang.org/x/tools v0.0.0-20190311212946-11955173bddd
15
+	golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135
16 16
 	google.golang.org/appengine v1.1.0 // indirect
17 17
 	google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8
18
-	honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099
18
+	honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc
19 19
 )
... ...
@@ -26,6 +26,7 @@ import (
26 26
 
27 27
 	"google.golang.org/grpc"
28 28
 	"google.golang.org/grpc/codes"
29
+	"google.golang.org/grpc/connectivity"
29 30
 	healthpb "google.golang.org/grpc/health/grpc_health_v1"
30 31
 	"google.golang.org/grpc/internal"
31 32
 	"google.golang.org/grpc/internal/backoff"
... ...
@@ -51,7 +52,11 @@ func init() {
51 51
 	internal.HealthCheckFunc = clientHealthCheck
52 52
 }
53 53
 
54
-func clientHealthCheck(ctx context.Context, newStream func() (interface{}, error), reportHealth func(bool), service string) error {
54
+const healthCheckMethod = "/grpc.health.v1.Health/Watch"
55
+
56
+// This function implements the protocol defined at:
57
+// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
58
+func clientHealthCheck(ctx context.Context, newStream func(string) (interface{}, error), setConnectivityState func(connectivity.State), service string) error {
55 59
 	tryCnt := 0
56 60
 
57 61
 retryConnection:
... ...
@@ -65,7 +70,8 @@ retryConnection:
65 65
 		if ctx.Err() != nil {
66 66
 			return nil
67 67
 		}
68
-		rawS, err := newStream()
68
+		setConnectivityState(connectivity.Connecting)
69
+		rawS, err := newStream(healthCheckMethod)
69 70
 		if err != nil {
70 71
 			continue retryConnection
71 72
 		}
... ...
@@ -73,7 +79,7 @@ retryConnection:
73 73
 		s, ok := rawS.(grpc.ClientStream)
74 74
 		// Ideally, this should never happen. But if it happens, the server is marked as healthy for LBing purposes.
75 75
 		if !ok {
76
-			reportHealth(true)
76
+			setConnectivityState(connectivity.Ready)
77 77
 			return fmt.Errorf("newStream returned %v (type %T); want grpc.ClientStream", rawS, rawS)
78 78
 		}
79 79
 
... ...
@@ -89,19 +95,23 @@ retryConnection:
89 89
 
90 90
 			// Reports healthy for the LBing purposes if health check is not implemented in the server.
91 91
 			if status.Code(err) == codes.Unimplemented {
92
-				reportHealth(true)
92
+				setConnectivityState(connectivity.Ready)
93 93
 				return err
94 94
 			}
95 95
 
96 96
 			// Reports unhealthy if server's Watch method gives an error other than UNIMPLEMENTED.
97 97
 			if err != nil {
98
-				reportHealth(false)
98
+				setConnectivityState(connectivity.TransientFailure)
99 99
 				continue retryConnection
100 100
 			}
101 101
 
102 102
 			// As a message has been received, removes the need for backoff for the next retry by reseting the try count.
103 103
 			tryCnt = 0
104
-			reportHealth(resp.Status == healthpb.HealthCheckResponse_SERVING)
104
+			if resp.Status == healthpb.HealthCheckResponse_SERVING {
105
+				setConnectivityState(connectivity.Ready)
106
+			} else {
107
+				setConnectivityState(connectivity.TransientFailure)
108
+			}
105 109
 		}
106 110
 	}
107 111
 }
... ...
@@ -24,6 +24,7 @@
24 24
 package channelz
25 25
 
26 26
 import (
27
+	"fmt"
27 28
 	"sort"
28 29
 	"sync"
29 30
 	"sync/atomic"
... ...
@@ -95,9 +96,14 @@ func (d *dbWrapper) get() *channelMap {
95 95
 
96 96
 // NewChannelzStorage initializes channelz data storage and id generator.
97 97
 //
98
+// This function returns a cleanup function to wait for all channelz state to be reset by the
99
+// grpc goroutines when those entities get closed. By using this cleanup function, we make sure tests
100
+// don't mess up each other, i.e. lingering goroutine from previous test doing entity removal happen
101
+// to remove some entity just register by the new test, since the id space is the same.
102
+//
98 103
 // Note: This function is exported for testing purpose only. User should not call
99 104
 // it in most cases.
100
-func NewChannelzStorage() {
105
+func NewChannelzStorage() (cleanup func() error) {
101 106
 	db.set(&channelMap{
102 107
 		topLevelChannels: make(map[int64]struct{}),
103 108
 		channels:         make(map[int64]*channel),
... ...
@@ -107,6 +113,28 @@ func NewChannelzStorage() {
107 107
 		subChannels:      make(map[int64]*subChannel),
108 108
 	})
109 109
 	idGen.reset()
110
+	return func() error {
111
+		var err error
112
+		cm := db.get()
113
+		if cm == nil {
114
+			return nil
115
+		}
116
+		for i := 0; i < 1000; i++ {
117
+			cm.mu.Lock()
118
+			if len(cm.topLevelChannels) == 0 && len(cm.servers) == 0 && len(cm.channels) == 0 && len(cm.subChannels) == 0 && len(cm.listenSockets) == 0 && len(cm.normalSockets) == 0 {
119
+				cm.mu.Unlock()
120
+				// all things stored in the channelz map have been cleared.
121
+				return nil
122
+			}
123
+			cm.mu.Unlock()
124
+			time.Sleep(10 * time.Millisecond)
125
+		}
126
+
127
+		cm.mu.Lock()
128
+		err = fmt.Errorf("after 10s the channelz map has not been cleaned up yet, topchannels: %d, servers: %d, channels: %d, subchannels: %d, listen sockets: %d, normal sockets: %d", len(cm.topLevelChannels), len(cm.servers), len(cm.channels), len(cm.subChannels), len(cm.listenSockets), len(cm.normalSockets))
129
+		cm.mu.Unlock()
130
+		return err
131
+	}
110 132
 }
111 133
 
112 134
 // GetTopChannels returns a slice of top channel's ChannelMetric, along with a
... ...
@@ -25,40 +25,11 @@ import (
25 25
 )
26 26
 
27 27
 const (
28
-	prefix              = "GRPC_GO_"
29
-	retryStr            = prefix + "RETRY"
30
-	requireHandshakeStr = prefix + "REQUIRE_HANDSHAKE"
31
-)
32
-
33
-// RequireHandshakeSetting describes the settings for handshaking.
34
-type RequireHandshakeSetting int
35
-
36
-const (
37
-	// RequireHandshakeOn indicates to wait for handshake before considering a
38
-	// connection ready/successful.
39
-	RequireHandshakeOn RequireHandshakeSetting = iota
40
-	// RequireHandshakeOff indicates to not wait for handshake before
41
-	// considering a connection ready/successful.
42
-	RequireHandshakeOff
28
+	prefix   = "GRPC_GO_"
29
+	retryStr = prefix + "RETRY"
43 30
 )
44 31
 
45 32
 var (
46 33
 	// Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on".
47 34
 	Retry = strings.EqualFold(os.Getenv(retryStr), "on")
48
-	// RequireHandshake is set based upon the GRPC_GO_REQUIRE_HANDSHAKE
49
-	// environment variable.
50
-	//
51
-	// Will be removed after the 1.18 release.
52
-	RequireHandshake = RequireHandshakeOn
53 35
 )
54
-
55
-func init() {
56
-	switch strings.ToLower(os.Getenv(requireHandshakeStr)) {
57
-	case "on":
58
-		fallthrough
59
-	default:
60
-		RequireHandshake = RequireHandshakeOn
61
-	case "off":
62
-		RequireHandshake = RequireHandshakeOff
63
-	}
64
-}
... ...
@@ -23,6 +23,8 @@ package internal
23 23
 import (
24 24
 	"context"
25 25
 	"time"
26
+
27
+	"google.golang.org/grpc/connectivity"
26 28
 )
27 29
 
28 30
 var (
... ...
@@ -37,10 +39,25 @@ var (
37 37
 	// KeepaliveMinPingTime is the minimum ping interval.  This must be 10s by
38 38
 	// default, but tests may wish to set it lower for convenience.
39 39
 	KeepaliveMinPingTime = 10 * time.Second
40
+	// ParseServiceConfig is a function to parse JSON service configs into
41
+	// opaque data structures.
42
+	ParseServiceConfig func(sc string) (interface{}, error)
43
+	// StatusRawProto is exported by status/status.go. This func returns a
44
+	// pointer to the wrapped Status proto for a given status.Status without a
45
+	// call to proto.Clone(). The returned Status proto should not be mutated by
46
+	// the caller.
47
+	StatusRawProto interface{} // func (*status.Status) *spb.Status
40 48
 )
41 49
 
42 50
 // HealthChecker defines the signature of the client-side LB channel health checking function.
43
-type HealthChecker func(ctx context.Context, newStream func() (interface{}, error), reportHealth func(bool), serviceName string) error
51
+//
52
+// The implementation is expected to create a health checking RPC stream by
53
+// calling newStream(), watch for the health status of serviceName, and report
54
+// it's health back by calling setConnectivityState().
55
+//
56
+// The health checking protocol is defined at:
57
+// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
58
+type HealthChecker func(ctx context.Context, newStream func(string) (interface{}, error), setConnectivityState func(connectivity.State), serviceName string) error
44 59
 
45 60
 const (
46 61
 	// CredsBundleModeFallback switches GoogleDefaultCreds to fallback mode.
... ...
@@ -23,6 +23,7 @@ import (
23 23
 	"fmt"
24 24
 	"runtime"
25 25
 	"sync"
26
+	"sync/atomic"
26 27
 
27 28
 	"golang.org/x/net/http2"
28 29
 	"golang.org/x/net/http2/hpack"
... ...
@@ -84,12 +85,24 @@ func (il *itemList) isEmpty() bool {
84 84
 // the control buffer of transport. They represent different aspects of
85 85
 // control tasks, e.g., flow control, settings, streaming resetting, etc.
86 86
 
87
+// maxQueuedTransportResponseFrames is the most queued "transport response"
88
+// frames we will buffer before preventing new reads from occurring on the
89
+// transport.  These are control frames sent in response to client requests,
90
+// such as RST_STREAM due to bad headers or settings acks.
91
+const maxQueuedTransportResponseFrames = 50
92
+
93
+type cbItem interface {
94
+	isTransportResponseFrame() bool
95
+}
96
+
87 97
 // registerStream is used to register an incoming stream with loopy writer.
88 98
 type registerStream struct {
89 99
 	streamID uint32
90 100
 	wq       *writeQuota
91 101
 }
92 102
 
103
+func (*registerStream) isTransportResponseFrame() bool { return false }
104
+
93 105
 // headerFrame is also used to register stream on the client-side.
94 106
 type headerFrame struct {
95 107
 	streamID   uint32
... ...
@@ -102,6 +115,10 @@ type headerFrame struct {
102 102
 	onOrphaned func(error)    // Valid on client-side
103 103
 }
104 104
 
105
+func (h *headerFrame) isTransportResponseFrame() bool {
106
+	return h.cleanup != nil && h.cleanup.rst // Results in a RST_STREAM
107
+}
108
+
105 109
 type cleanupStream struct {
106 110
 	streamID uint32
107 111
 	rst      bool
... ...
@@ -109,6 +126,8 @@ type cleanupStream struct {
109 109
 	onWrite  func()
110 110
 }
111 111
 
112
+func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
113
+
112 114
 type dataFrame struct {
113 115
 	streamID  uint32
114 116
 	endStream bool
... ...
@@ -119,27 +138,41 @@ type dataFrame struct {
119 119
 	onEachWrite func()
120 120
 }
121 121
 
122
+func (*dataFrame) isTransportResponseFrame() bool { return false }
123
+
122 124
 type incomingWindowUpdate struct {
123 125
 	streamID  uint32
124 126
 	increment uint32
125 127
 }
126 128
 
129
+func (*incomingWindowUpdate) isTransportResponseFrame() bool { return false }
130
+
127 131
 type outgoingWindowUpdate struct {
128 132
 	streamID  uint32
129 133
 	increment uint32
130 134
 }
131 135
 
136
+func (*outgoingWindowUpdate) isTransportResponseFrame() bool {
137
+	return false // window updates are throttled by thresholds
138
+}
139
+
132 140
 type incomingSettings struct {
133 141
 	ss []http2.Setting
134 142
 }
135 143
 
144
+func (*incomingSettings) isTransportResponseFrame() bool { return true } // Results in a settings ACK
145
+
136 146
 type outgoingSettings struct {
137 147
 	ss []http2.Setting
138 148
 }
139 149
 
150
+func (*outgoingSettings) isTransportResponseFrame() bool { return false }
151
+
140 152
 type incomingGoAway struct {
141 153
 }
142 154
 
155
+func (*incomingGoAway) isTransportResponseFrame() bool { return false }
156
+
143 157
 type goAway struct {
144 158
 	code      http2.ErrCode
145 159
 	debugData []byte
... ...
@@ -147,15 +180,21 @@ type goAway struct {
147 147
 	closeConn bool
148 148
 }
149 149
 
150
+func (*goAway) isTransportResponseFrame() bool { return false }
151
+
150 152
 type ping struct {
151 153
 	ack  bool
152 154
 	data [8]byte
153 155
 }
154 156
 
157
+func (*ping) isTransportResponseFrame() bool { return true }
158
+
155 159
 type outFlowControlSizeRequest struct {
156 160
 	resp chan uint32
157 161
 }
158 162
 
163
+func (*outFlowControlSizeRequest) isTransportResponseFrame() bool { return false }
164
+
159 165
 type outStreamState int
160 166
 
161 167
 const (
... ...
@@ -238,6 +277,14 @@ type controlBuffer struct {
238 238
 	consumerWaiting bool
239 239
 	list            *itemList
240 240
 	err             error
241
+
242
+	// transportResponseFrames counts the number of queued items that represent
243
+	// the response of an action initiated by the peer.  trfChan is created
244
+	// when transportResponseFrames >= maxQueuedTransportResponseFrames and is
245
+	// closed and nilled when transportResponseFrames drops below the
246
+	// threshold.  Both fields are protected by mu.
247
+	transportResponseFrames int
248
+	trfChan                 atomic.Value // *chan struct{}
241 249
 }
242 250
 
243 251
 func newControlBuffer(done <-chan struct{}) *controlBuffer {
... ...
@@ -248,12 +295,24 @@ func newControlBuffer(done <-chan struct{}) *controlBuffer {
248 248
 	}
249 249
 }
250 250
 
251
-func (c *controlBuffer) put(it interface{}) error {
251
+// throttle blocks if there are too many incomingSettings/cleanupStreams in the
252
+// controlbuf.
253
+func (c *controlBuffer) throttle() {
254
+	ch, _ := c.trfChan.Load().(*chan struct{})
255
+	if ch != nil {
256
+		select {
257
+		case <-*ch:
258
+		case <-c.done:
259
+		}
260
+	}
261
+}
262
+
263
+func (c *controlBuffer) put(it cbItem) error {
252 264
 	_, err := c.executeAndPut(nil, it)
253 265
 	return err
254 266
 }
255 267
 
256
-func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) {
268
+func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (bool, error) {
257 269
 	var wakeUp bool
258 270
 	c.mu.Lock()
259 271
 	if c.err != nil {
... ...
@@ -271,6 +330,15 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{
271 271
 		c.consumerWaiting = false
272 272
 	}
273 273
 	c.list.enqueue(it)
274
+	if it.isTransportResponseFrame() {
275
+		c.transportResponseFrames++
276
+		if c.transportResponseFrames == maxQueuedTransportResponseFrames {
277
+			// We are adding the frame that puts us over the threshold; create
278
+			// a throttling channel.
279
+			ch := make(chan struct{})
280
+			c.trfChan.Store(&ch)
281
+		}
282
+	}
274 283
 	c.mu.Unlock()
275 284
 	if wakeUp {
276 285
 		select {
... ...
@@ -304,7 +372,17 @@ func (c *controlBuffer) get(block bool) (interface{}, error) {
304 304
 			return nil, c.err
305 305
 		}
306 306
 		if !c.list.isEmpty() {
307
-			h := c.list.dequeue()
307
+			h := c.list.dequeue().(cbItem)
308
+			if h.isTransportResponseFrame() {
309
+				if c.transportResponseFrames == maxQueuedTransportResponseFrames {
310
+					// We are removing the frame that put us over the
311
+					// threshold; close and clear the throttling channel.
312
+					ch := c.trfChan.Load().(*chan struct{})
313
+					close(*ch)
314
+					c.trfChan.Store((*chan struct{})(nil))
315
+				}
316
+				c.transportResponseFrames--
317
+			}
308 318
 			c.mu.Unlock()
309 319
 			return h, nil
310 320
 		}
... ...
@@ -149,6 +149,7 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
149 149
 		n = uint32(math.MaxInt32)
150 150
 	}
151 151
 	f.mu.Lock()
152
+	defer f.mu.Unlock()
152 153
 	// estSenderQuota is the receiver's view of the maximum number of bytes the sender
153 154
 	// can send without a window update.
154 155
 	estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
... ...
@@ -169,10 +170,8 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
169 169
 			// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
170 170
 			f.delta = n
171 171
 		}
172
-		f.mu.Unlock()
173 172
 		return f.delta
174 173
 	}
175
-	f.mu.Unlock()
176 174
 	return 0
177 175
 }
178 176
 
... ...
@@ -24,6 +24,7 @@
24 24
 package transport
25 25
 
26 26
 import (
27
+	"bytes"
27 28
 	"context"
28 29
 	"errors"
29 30
 	"fmt"
... ...
@@ -347,7 +348,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
347 347
 		ht.stats.HandleRPC(s.ctx, inHeader)
348 348
 	}
349 349
 	s.trReader = &transportReader{
350
-		reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
350
+		reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
351 351
 		windowHandler: func(int) {},
352 352
 	}
353 353
 
... ...
@@ -361,7 +362,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
361 361
 		for buf := make([]byte, readSize); ; {
362 362
 			n, err := req.Body.Read(buf)
363 363
 			if n > 0 {
364
-				s.buf.put(recvMsg{data: buf[:n:n]})
364
+				s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
365 365
 				buf = buf[n:]
366 366
 			}
367 367
 			if err != nil {
... ...
@@ -117,6 +117,8 @@ type http2Client struct {
117 117
 
118 118
 	onGoAway func(GoAwayReason)
119 119
 	onClose  func()
120
+
121
+	bufferPool *bufferPool
120 122
 }
121 123
 
122 124
 func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
... ...
@@ -249,6 +251,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
249 249
 		onGoAway:              onGoAway,
250 250
 		onClose:               onClose,
251 251
 		keepaliveEnabled:      keepaliveEnabled,
252
+		bufferPool:            newBufferPool(),
252 253
 	}
253 254
 	t.controlBuf = newControlBuffer(t.ctxDone)
254 255
 	if opts.InitialWindowSize >= defaultWindowSize {
... ...
@@ -367,6 +370,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
367 367
 			closeStream: func(err error) {
368 368
 				t.CloseStream(s, err)
369 369
 			},
370
+			freeBuffer: t.bufferPool.put,
370 371
 		},
371 372
 		windowHandler: func(n int) {
372 373
 			t.updateWindow(s, uint32(n))
... ...
@@ -437,6 +441,15 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
437 437
 
438 438
 	if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
439 439
 		var k string
440
+		for k, vv := range md {
441
+			// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
442
+			if isReservedHeader(k) {
443
+				continue
444
+			}
445
+			for _, v := range vv {
446
+				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
447
+			}
448
+		}
440 449
 		for _, vv := range added {
441 450
 			for i, v := range vv {
442 451
 				if i%2 == 0 {
... ...
@@ -450,15 +463,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
450 450
 				headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)})
451 451
 			}
452 452
 		}
453
-		for k, vv := range md {
454
-			// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
455
-			if isReservedHeader(k) {
456
-				continue
457
-			}
458
-			for _, v := range vv {
459
-				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
460
-			}
461
-		}
462 453
 	}
463 454
 	if md, ok := t.md.(*metadata.MD); ok {
464 455
 		for k, vv := range *md {
... ...
@@ -489,6 +493,9 @@ func (t *http2Client) createAudience(callHdr *CallHdr) string {
489 489
 }
490 490
 
491 491
 func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
492
+	if len(t.perRPCCreds) == 0 {
493
+		return nil, nil
494
+	}
492 495
 	authData := map[string]string{}
493 496
 	for _, c := range t.perRPCCreds {
494 497
 		data, err := c.GetRequestMetadata(ctx, audience)
... ...
@@ -509,7 +516,7 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s
509 509
 }
510 510
 
511 511
 func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) {
512
-	callAuthData := map[string]string{}
512
+	var callAuthData map[string]string
513 513
 	// Check if credentials.PerRPCCredentials were provided via call options.
514 514
 	// Note: if these credentials are provided both via dial options and call
515 515
 	// options, then both sets of credentials will be applied.
... ...
@@ -521,6 +528,7 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
521 521
 		if err != nil {
522 522
 			return nil, status.Errorf(codes.Internal, "transport: %v", err)
523 523
 		}
524
+		callAuthData = make(map[string]string, len(data))
524 525
 		for k, v := range data {
525 526
 			// Capital header names are illegal in HTTP/2
526 527
 			k = strings.ToLower(k)
... ...
@@ -549,10 +557,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
549 549
 		s.write(recvMsg{err: err})
550 550
 		close(s.done)
551 551
 		// If headerChan isn't closed, then close it.
552
-		if atomic.SwapUint32(&s.headerDone, 1) == 0 {
552
+		if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
553 553
 			close(s.headerChan)
554 554
 		}
555
-
556 555
 	}
557 556
 	hdr := &headerFrame{
558 557
 		hf:        headerFields,
... ...
@@ -713,7 +720,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
713 713
 		s.write(recvMsg{err: err})
714 714
 	}
715 715
 	// If headerChan isn't closed, then close it.
716
-	if atomic.SwapUint32(&s.headerDone, 1) == 0 {
716
+	if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
717 717
 		s.noHeaders = true
718 718
 		close(s.headerChan)
719 719
 	}
... ...
@@ -765,6 +772,9 @@ func (t *http2Client) Close() error {
765 765
 		t.mu.Unlock()
766 766
 		return nil
767 767
 	}
768
+	// Call t.onClose before setting the state to closing to prevent the client
769
+	// from attempting to create new streams ASAP.
770
+	t.onClose()
768 771
 	t.state = closing
769 772
 	streams := t.activeStreams
770 773
 	t.activeStreams = nil
... ...
@@ -785,7 +795,6 @@ func (t *http2Client) Close() error {
785 785
 		}
786 786
 		t.statsHandler.HandleConn(t.ctx, connEnd)
787 787
 	}
788
-	t.onClose()
789 788
 	return err
790 789
 }
791 790
 
... ...
@@ -794,21 +803,21 @@ func (t *http2Client) Close() error {
794 794
 // stream is closed.  If there are no active streams, the transport is closed
795 795
 // immediately.  This does nothing if the transport is already draining or
796 796
 // closing.
797
-func (t *http2Client) GracefulClose() error {
797
+func (t *http2Client) GracefulClose() {
798 798
 	t.mu.Lock()
799 799
 	// Make sure we move to draining only from active.
800 800
 	if t.state == draining || t.state == closing {
801 801
 		t.mu.Unlock()
802
-		return nil
802
+		return
803 803
 	}
804 804
 	t.state = draining
805 805
 	active := len(t.activeStreams)
806 806
 	t.mu.Unlock()
807 807
 	if active == 0 {
808
-		return t.Close()
808
+		t.Close()
809
+		return
809 810
 	}
810 811
 	t.controlBuf.put(&incomingGoAway{})
811
-	return nil
812 812
 }
813 813
 
814 814
 // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
... ...
@@ -946,9 +955,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
946 946
 		// guarantee f.Data() is consumed before the arrival of next frame.
947 947
 		// Can this copy be eliminated?
948 948
 		if len(f.Data()) > 0 {
949
-			data := make([]byte, len(f.Data()))
950
-			copy(data, f.Data())
951
-			s.write(recvMsg{data: data})
949
+			buffer := t.bufferPool.get()
950
+			buffer.Reset()
951
+			buffer.Write(f.Data())
952
+			s.write(recvMsg{buffer: buffer})
952 953
 		}
953 954
 	}
954 955
 	// The server has closed the stream without sending trailers.  Record that
... ...
@@ -973,9 +983,9 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
973 973
 		statusCode = codes.Unknown
974 974
 	}
975 975
 	if statusCode == codes.Canceled {
976
-		// Our deadline was already exceeded, and that was likely the cause of
977
-		// this cancelation.  Alter the status code accordingly.
978
-		if d, ok := s.ctx.Deadline(); ok && d.After(time.Now()) {
976
+		if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
977
+			// Our deadline was already exceeded, and that was likely the cause
978
+			// of this cancelation.  Alter the status code accordingly.
979 979
 			statusCode = codes.DeadlineExceeded
980 980
 		}
981 981
 	}
... ...
@@ -1080,11 +1090,12 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
1080 1080
 	default:
1081 1081
 		t.setGoAwayReason(f)
1082 1082
 		close(t.goAway)
1083
-		t.state = draining
1084 1083
 		t.controlBuf.put(&incomingGoAway{})
1085
-
1086
-		// This has to be a new goroutine because we're still using the current goroutine to read in the transport.
1084
+		// Notify the clientconn about the GOAWAY before we set the state to
1085
+		// draining, to allow the client to stop attempting to create streams
1086
+		// before disallowing new streams on this connection.
1087 1087
 		t.onGoAway(t.goAwayReason)
1088
+		t.state = draining
1088 1089
 	}
1089 1090
 	// All streams with IDs greater than the GoAwayId
1090 1091
 	// and smaller than the previous GoAway ID should be killed.
... ...
@@ -1142,26 +1153,24 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
1142 1142
 	}
1143 1143
 	endStream := frame.StreamEnded()
1144 1144
 	atomic.StoreUint32(&s.bytesReceived, 1)
1145
-	initialHeader := atomic.SwapUint32(&s.headerDone, 1) == 0
1145
+	initialHeader := atomic.LoadUint32(&s.headerChanClosed) == 0
1146 1146
 
1147 1147
 	if !initialHeader && !endStream {
1148
-		// As specified by RFC 7540, a HEADERS frame (and associated CONTINUATION frames) can only appear
1149
-		// at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
1148
+		// As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
1150 1149
 		st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream")
1151 1150
 		t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false)
1152 1151
 		return
1153 1152
 	}
1154 1153
 
1155 1154
 	state := &decodeState{}
1156
-	// Initialize isGRPC value to be !initialHeader, since if a gRPC ResponseHeader has been received
1157
-	// which indicates peer speaking gRPC, we are in gRPC mode.
1155
+	// Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode.
1158 1156
 	state.data.isGRPC = !initialHeader
1159 1157
 	if err := state.decodeHeader(frame); err != nil {
1160 1158
 		t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
1161 1159
 		return
1162 1160
 	}
1163 1161
 
1164
-	var isHeader bool
1162
+	isHeader := false
1165 1163
 	defer func() {
1166 1164
 		if t.statsHandler != nil {
1167 1165
 			if isHeader {
... ...
@@ -1180,10 +1189,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
1180 1180
 		}
1181 1181
 	}()
1182 1182
 
1183
-	// If headers haven't been received yet.
1184
-	if initialHeader {
1183
+	// If headerChan hasn't been closed yet
1184
+	if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
1185 1185
 		if !endStream {
1186
-			// Headers frame is ResponseHeader.
1186
+			// HEADERS frame block carries a Response-Headers.
1187 1187
 			isHeader = true
1188 1188
 			// These values can be set without any synchronization because
1189 1189
 			// stream goroutine will read it only after seeing a closed
... ...
@@ -1192,14 +1201,17 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
1192 1192
 			if len(state.data.mdata) > 0 {
1193 1193
 				s.header = state.data.mdata
1194 1194
 			}
1195
-			close(s.headerChan)
1196
-			return
1195
+		} else {
1196
+			// HEADERS frame block carries a Trailers-Only.
1197
+			s.noHeaders = true
1197 1198
 		}
1198
-		// Headers frame is Trailers-only.
1199
-		s.noHeaders = true
1200 1199
 		close(s.headerChan)
1201 1200
 	}
1202 1201
 
1202
+	if !endStream {
1203
+		return
1204
+	}
1205
+
1203 1206
 	// if client received END_STREAM from server while stream was still active, send RST_STREAM
1204 1207
 	rst := s.getState() == streamActive
1205 1208
 	t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true)
... ...
@@ -1233,6 +1245,7 @@ func (t *http2Client) reader() {
1233 1233
 
1234 1234
 	// loop to keep reading incoming messages on this transport.
1235 1235
 	for {
1236
+		t.controlBuf.throttle()
1236 1237
 		frame, err := t.framer.fr.ReadFrame()
1237 1238
 		if t.keepaliveEnabled {
1238 1239
 			atomic.CompareAndSwapUint32(&t.activity, 0, 1)
... ...
@@ -1320,6 +1333,7 @@ func (t *http2Client) keepalive() {
1320 1320
 					timer.Reset(t.kp.Time)
1321 1321
 					continue
1322 1322
 				}
1323
+				infof("transport: closing client transport due to idleness.")
1323 1324
 				t.Close()
1324 1325
 				return
1325 1326
 			case <-t.ctx.Done():
... ...
@@ -35,9 +35,11 @@ import (
35 35
 	"golang.org/x/net/http2"
36 36
 	"golang.org/x/net/http2/hpack"
37 37
 
38
+	spb "google.golang.org/genproto/googleapis/rpc/status"
38 39
 	"google.golang.org/grpc/codes"
39 40
 	"google.golang.org/grpc/credentials"
40 41
 	"google.golang.org/grpc/grpclog"
42
+	"google.golang.org/grpc/internal"
41 43
 	"google.golang.org/grpc/internal/channelz"
42 44
 	"google.golang.org/grpc/internal/grpcrand"
43 45
 	"google.golang.org/grpc/keepalive"
... ...
@@ -55,6 +57,9 @@ var (
55 55
 	// ErrHeaderListSizeLimitViolation indicates that the header list size is larger
56 56
 	// than the limit set by peer.
57 57
 	ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer")
58
+	// statusRawProto is a function to get to the raw status proto wrapped in a
59
+	// status.Status without a proto.Clone().
60
+	statusRawProto = internal.StatusRawProto.(func(*status.Status) *spb.Status)
58 61
 )
59 62
 
60 63
 // http2Server implements the ServerTransport interface with HTTP2.
... ...
@@ -119,6 +124,7 @@ type http2Server struct {
119 119
 	// Fields below are for channelz metric collection.
120 120
 	channelzID int64 // channelz unique identification number
121 121
 	czData     *channelzData
122
+	bufferPool *bufferPool
122 123
 }
123 124
 
124 125
 // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
... ...
@@ -220,6 +226,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
220 220
 		kep:               kep,
221 221
 		initialWindowSize: iwz,
222 222
 		czData:            new(channelzData),
223
+		bufferPool:        newBufferPool(),
223 224
 	}
224 225
 	t.controlBuf = newControlBuffer(t.ctxDone)
225 226
 	if dynamicWindow {
... ...
@@ -405,9 +412,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
405 405
 	s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
406 406
 	s.trReader = &transportReader{
407 407
 		reader: &recvBufferReader{
408
-			ctx:     s.ctx,
409
-			ctxDone: s.ctxDone,
410
-			recv:    s.buf,
408
+			ctx:        s.ctx,
409
+			ctxDone:    s.ctxDone,
410
+			recv:       s.buf,
411
+			freeBuffer: t.bufferPool.put,
411 412
 		},
412 413
 		windowHandler: func(n int) {
413 414
 			t.updateWindow(s, uint32(n))
... ...
@@ -428,6 +436,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
428 428
 func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
429 429
 	defer close(t.readerDone)
430 430
 	for {
431
+		t.controlBuf.throttle()
431 432
 		frame, err := t.framer.fr.ReadFrame()
432 433
 		atomic.StoreUint32(&t.activity, 1)
433 434
 		if err != nil {
... ...
@@ -591,9 +600,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
591 591
 		// guarantee f.Data() is consumed before the arrival of next frame.
592 592
 		// Can this copy be eliminated?
593 593
 		if len(f.Data()) > 0 {
594
-			data := make([]byte, len(f.Data()))
595
-			copy(data, f.Data())
596
-			s.write(recvMsg{data: data})
594
+			buffer := t.bufferPool.get()
595
+			buffer.Reset()
596
+			buffer.Write(f.Data())
597
+			s.write(recvMsg{buffer: buffer})
597 598
 		}
598 599
 	}
599 600
 	if f.Header().Flags.Has(http2.FlagDataEndStream) {
... ...
@@ -757,6 +767,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
757 757
 	return nil
758 758
 }
759 759
 
760
+func (t *http2Server) setResetPingStrikes() {
761
+	atomic.StoreUint32(&t.resetPingStrikes, 1)
762
+}
763
+
760 764
 func (t *http2Server) writeHeaderLocked(s *Stream) error {
761 765
 	// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
762 766
 	// first and create a slice of that exact size.
... ...
@@ -771,9 +785,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
771 771
 		streamID:  s.id,
772 772
 		hf:        headerFields,
773 773
 		endStream: false,
774
-		onWrite: func() {
775
-			atomic.StoreUint32(&t.resetPingStrikes, 1)
776
-		},
774
+		onWrite:   t.setResetPingStrikes,
777 775
 	})
778 776
 	if !success {
779 777
 		if err != nil {
... ...
@@ -817,7 +829,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
817 817
 	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
818 818
 	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
819 819
 
820
-	if p := st.Proto(); p != nil && len(p.Details) > 0 {
820
+	if p := statusRawProto(st); p != nil && len(p.Details) > 0 {
821 821
 		stBytes, err := proto.Marshal(p)
822 822
 		if err != nil {
823 823
 			// TODO: return error instead, when callers are able to handle it.
... ...
@@ -833,9 +845,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
833 833
 		streamID:  s.id,
834 834
 		hf:        headerFields,
835 835
 		endStream: true,
836
-		onWrite: func() {
837
-			atomic.StoreUint32(&t.resetPingStrikes, 1)
838
-		},
836
+		onWrite:   t.setResetPingStrikes,
839 837
 	}
840 838
 	s.hdrMu.Unlock()
841 839
 	success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader)
... ...
@@ -887,12 +897,10 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
887 887
 	hdr = append(hdr, data[:emptyLen]...)
888 888
 	data = data[emptyLen:]
889 889
 	df := &dataFrame{
890
-		streamID: s.id,
891
-		h:        hdr,
892
-		d:        data,
893
-		onEachWrite: func() {
894
-			atomic.StoreUint32(&t.resetPingStrikes, 1)
895
-		},
890
+		streamID:    s.id,
891
+		h:           hdr,
892
+		d:           data,
893
+		onEachWrite: t.setResetPingStrikes,
896 894
 	}
897 895
 	if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
898 896
 		select {
... ...
@@ -958,6 +966,7 @@ func (t *http2Server) keepalive() {
958 958
 			select {
959 959
 			case <-maxAge.C:
960 960
 				// Close the connection after grace period.
961
+				infof("transport: closing server transport due to maximum connection age.")
961 962
 				t.Close()
962 963
 				// Resetting the timer so that the clean-up doesn't deadlock.
963 964
 				maxAge.Reset(infinity)
... ...
@@ -971,6 +980,7 @@ func (t *http2Server) keepalive() {
971 971
 				continue
972 972
 			}
973 973
 			if pingSent {
974
+				infof("transport: closing server transport due to idleness.")
974 975
 				t.Close()
975 976
 				// Resetting the timer so that the clean-up doesn't deadlock.
976 977
 				keepalive.Reset(infinity)
... ...
@@ -1019,13 +1029,7 @@ func (t *http2Server) Close() error {
1019 1019
 }
1020 1020
 
1021 1021
 // deleteStream deletes the stream s from transport's active streams.
1022
-func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState streamState) {
1023
-	oldState = s.swapState(streamDone)
1024
-	if oldState == streamDone {
1025
-		// If the stream was already done, return.
1026
-		return oldState
1027
-	}
1028
-
1022
+func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
1029 1023
 	// In case stream sending and receiving are invoked in separate
1030 1024
 	// goroutines (e.g., bi-directional streaming), cancel needs to be
1031 1025
 	// called to interrupt the potential blocking on other goroutines.
... ...
@@ -1047,15 +1051,13 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState stream
1047 1047
 			atomic.AddInt64(&t.czData.streamsFailed, 1)
1048 1048
 		}
1049 1049
 	}
1050
-
1051
-	return oldState
1052 1050
 }
1053 1051
 
1054 1052
 // finishStream closes the stream and puts the trailing headerFrame into controlbuf.
1055 1053
 func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
1056
-	oldState := t.deleteStream(s, eosReceived)
1057
-	// If the stream is already closed, then don't put trailing header to controlbuf.
1054
+	oldState := s.swapState(streamDone)
1058 1055
 	if oldState == streamDone {
1056
+		// If the stream was already done, return.
1059 1057
 		return
1060 1058
 	}
1061 1059
 
... ...
@@ -1063,14 +1065,18 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h
1063 1063
 		streamID: s.id,
1064 1064
 		rst:      rst,
1065 1065
 		rstCode:  rstCode,
1066
-		onWrite:  func() {},
1066
+		onWrite: func() {
1067
+			t.deleteStream(s, eosReceived)
1068
+		},
1067 1069
 	}
1068 1070
 	t.controlBuf.put(hdr)
1069 1071
 }
1070 1072
 
1071 1073
 // closeStream clears the footprint of a stream when the stream is not needed any more.
1072 1074
 func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
1075
+	s.swapState(streamDone)
1073 1076
 	t.deleteStream(s, eosReceived)
1077
+
1074 1078
 	t.controlBuf.put(&cleanupStream{
1075 1079
 		streamID: s.id,
1076 1080
 		rst:      rst,
... ...
@@ -22,6 +22,7 @@
22 22
 package transport
23 23
 
24 24
 import (
25
+	"bytes"
25 26
 	"context"
26 27
 	"errors"
27 28
 	"fmt"
... ...
@@ -39,10 +40,32 @@ import (
39 39
 	"google.golang.org/grpc/tap"
40 40
 )
41 41
 
42
+type bufferPool struct {
43
+	pool sync.Pool
44
+}
45
+
46
+func newBufferPool() *bufferPool {
47
+	return &bufferPool{
48
+		pool: sync.Pool{
49
+			New: func() interface{} {
50
+				return new(bytes.Buffer)
51
+			},
52
+		},
53
+	}
54
+}
55
+
56
+func (p *bufferPool) get() *bytes.Buffer {
57
+	return p.pool.Get().(*bytes.Buffer)
58
+}
59
+
60
+func (p *bufferPool) put(b *bytes.Buffer) {
61
+	p.pool.Put(b)
62
+}
63
+
42 64
 // recvMsg represents the received msg from the transport. All transport
43 65
 // protocol specific info has been removed.
44 66
 type recvMsg struct {
45
-	data []byte
67
+	buffer *bytes.Buffer
46 68
 	// nil: received some data
47 69
 	// io.EOF: stream is completed. data is nil.
48 70
 	// other non-nil error: transport failure. data is nil.
... ...
@@ -117,8 +140,9 @@ type recvBufferReader struct {
117 117
 	ctx         context.Context
118 118
 	ctxDone     <-chan struct{} // cache of ctx.Done() (for performance).
119 119
 	recv        *recvBuffer
120
-	last        []byte // Stores the remaining data in the previous calls.
120
+	last        *bytes.Buffer // Stores the remaining data in the previous calls.
121 121
 	err         error
122
+	freeBuffer  func(*bytes.Buffer)
122 123
 }
123 124
 
124 125
 // Read reads the next len(p) bytes from last. If last is drained, it tries to
... ...
@@ -128,10 +152,13 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
128 128
 	if r.err != nil {
129 129
 		return 0, r.err
130 130
 	}
131
-	if r.last != nil && len(r.last) > 0 {
131
+	if r.last != nil {
132 132
 		// Read remaining data left in last call.
133
-		copied := copy(p, r.last)
134
-		r.last = r.last[copied:]
133
+		copied, _ := r.last.Read(p)
134
+		if r.last.Len() == 0 {
135
+			r.freeBuffer(r.last)
136
+			r.last = nil
137
+		}
135 138
 		return copied, nil
136 139
 	}
137 140
 	if r.closeStream != nil {
... ...
@@ -157,6 +184,19 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
157 157
 	// r.readAdditional acts on that message and returns the necessary error.
158 158
 	select {
159 159
 	case <-r.ctxDone:
160
+		// Note that this adds the ctx error to the end of recv buffer, and
161
+		// reads from the head. This will delay the error until recv buffer is
162
+		// empty, thus will delay ctx cancellation in Recv().
163
+		//
164
+		// It's done this way to fix a race between ctx cancel and trailer. The
165
+		// race was, stream.Recv() may return ctx error if ctxDone wins the
166
+		// race, but stream.Trailer() may return a non-nil md because the stream
167
+		// was not marked as done when trailer is received. This closeStream
168
+		// call will mark stream as done, thus fix the race.
169
+		//
170
+		// TODO: delaying ctx error seems like a unnecessary side effect. What
171
+		// we really want is to mark the stream as done, and return ctx error
172
+		// faster.
160 173
 		r.closeStream(ContextErr(r.ctx.Err()))
161 174
 		m := <-r.recv.get()
162 175
 		return r.readAdditional(m, p)
... ...
@@ -170,8 +210,13 @@ func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error
170 170
 	if m.err != nil {
171 171
 		return 0, m.err
172 172
 	}
173
-	copied := copy(p, m.data)
174
-	r.last = m.data[copied:]
173
+	copied, _ := m.buffer.Read(p)
174
+	if m.buffer.Len() == 0 {
175
+		r.freeBuffer(m.buffer)
176
+		r.last = nil
177
+	} else {
178
+		r.last = m.buffer
179
+	}
175 180
 	return copied, nil
176 181
 }
177 182
 
... ...
@@ -204,8 +249,8 @@ type Stream struct {
204 204
 	// is used to adjust flow control, if needed.
205 205
 	requestRead func(int)
206 206
 
207
-	headerChan chan struct{} // closed to indicate the end of header metadata.
208
-	headerDone uint32        // set when headerChan is closed. Used to avoid closing headerChan multiple times.
207
+	headerChan       chan struct{} // closed to indicate the end of header metadata.
208
+	headerChanClosed uint32        // set when headerChan is closed. Used to avoid closing headerChan multiple times.
209 209
 
210 210
 	// hdrMu protects header and trailer metadata on the server-side.
211 211
 	hdrMu sync.Mutex
... ...
@@ -266,6 +311,14 @@ func (s *Stream) waitOnHeader() error {
266 266
 	}
267 267
 	select {
268 268
 	case <-s.ctx.Done():
269
+		// We prefer success over failure when reading messages because we delay
270
+		// context error in stream.Read(). To keep behavior consistent, we also
271
+		// prefer success here.
272
+		select {
273
+		case <-s.headerChan:
274
+			return nil
275
+		default:
276
+		}
269 277
 		return ContextErr(s.ctx.Err())
270 278
 	case <-s.headerChan:
271 279
 		return nil
... ...
@@ -578,9 +631,12 @@ type ClientTransport interface {
578 578
 	// is called only once.
579 579
 	Close() error
580 580
 
581
-	// GracefulClose starts to tear down the transport. It stops accepting
582
-	// new RPCs and wait the completion of the pending RPCs.
583
-	GracefulClose() error
581
+	// GracefulClose starts to tear down the transport: the transport will stop
582
+	// accepting new RPCs and NewStream will return error. Once all streams are
583
+	// finished, the transport will close.
584
+	//
585
+	// It does not block.
586
+	GracefulClose()
584 587
 
585 588
 	// Write sends the data for the given stream. A nil stream indicates
586 589
 	// the write is to be performed on the transport as a whole.
... ...
@@ -17,9 +17,8 @@
17 17
  */
18 18
 
19 19
 // Package naming defines the naming API and related data structures for gRPC.
20
-// The interface is EXPERIMENTAL and may be subject to change.
21 20
 //
22
-// Deprecated: please use package resolver.
21
+// This package is deprecated: please use package resolver instead.
23 22
 package naming
24 23
 
25 24
 // Operation defines the corresponding operations for a name resolution change.
... ...
@@ -120,6 +120,14 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
120 120
 			bp.mu.Unlock()
121 121
 			select {
122 122
 			case <-ctx.Done():
123
+				if connectionErr := bp.connectionError(); connectionErr != nil {
124
+					switch ctx.Err() {
125
+					case context.DeadlineExceeded:
126
+						return nil, nil, status.Errorf(codes.DeadlineExceeded, "latest connection error: %v", connectionErr)
127
+					case context.Canceled:
128
+						return nil, nil, status.Errorf(codes.Canceled, "latest connection error: %v", connectionErr)
129
+					}
130
+				}
123 131
 				return nil, nil, ctx.Err()
124 132
 			case <-ch:
125 133
 			}
... ...
@@ -51,14 +51,18 @@ type pickfirstBalancer struct {
51 51
 
52 52
 func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
53 53
 	if err != nil {
54
-		grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err)
54
+		if grpclog.V(2) {
55
+			grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err)
56
+		}
55 57
 		return
56 58
 	}
57 59
 	if b.sc == nil {
58 60
 		b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
59 61
 		if err != nil {
60 62
 			//TODO(yuxuanli): why not change the cc state to Idle?
61
-			grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err)
63
+			if grpclog.V(2) {
64
+				grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err)
65
+			}
62 66
 			return
63 67
 		}
64 68
 		b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc})
... ...
@@ -70,9 +74,13 @@ func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err er
70 70
 }
71 71
 
72 72
 func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
73
-	grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s)
73
+	if grpclog.V(2) {
74
+		grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s)
75
+	}
74 76
 	if b.sc != sc {
75
-		grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized")
77
+		if grpclog.V(2) {
78
+			grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized")
79
+		}
76 80
 		return
77 81
 	}
78 82
 	if s == connectivity.Shutdown {
79 83
new file mode 100644
... ...
@@ -0,0 +1,64 @@
0
+/*
1
+ *
2
+ * Copyright 2019 gRPC authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ *     http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ *
16
+ */
17
+
18
+package grpc
19
+
20
+import (
21
+	"google.golang.org/grpc/codes"
22
+	"google.golang.org/grpc/status"
23
+)
24
+
25
+// PreparedMsg is responsible for creating a Marshalled and Compressed object.
26
+//
27
+// This API is EXPERIMENTAL.
28
+type PreparedMsg struct {
29
+	// Struct for preparing msg before sending them
30
+	encodedData []byte
31
+	hdr         []byte
32
+	payload     []byte
33
+}
34
+
35
+// Encode marshalls and compresses the message using the codec and compressor for the stream.
36
+func (p *PreparedMsg) Encode(s Stream, msg interface{}) error {
37
+	ctx := s.Context()
38
+	rpcInfo, ok := rpcInfoFromContext(ctx)
39
+	if !ok {
40
+		return status.Errorf(codes.Internal, "grpc: unable to get rpcInfo")
41
+	}
42
+
43
+	// check if the context has the relevant information to prepareMsg
44
+	if rpcInfo.preloaderInfo == nil {
45
+		return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil")
46
+	}
47
+	if rpcInfo.preloaderInfo.codec == nil {
48
+		return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil")
49
+	}
50
+
51
+	// prepare the msg
52
+	data, err := encode(rpcInfo.preloaderInfo.codec, msg)
53
+	if err != nil {
54
+		return err
55
+	}
56
+	p.encodedData = data
57
+	compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp)
58
+	if err != nil {
59
+		return err
60
+	}
61
+	p.hdr, p.payload = msgHeader(data, compData)
62
+	return nil
63
+}
... ...
@@ -66,6 +66,9 @@ var (
66 66
 
67 67
 var (
68 68
 	defaultResolver netResolver = net.DefaultResolver
69
+	// To prevent excessive re-resolution, we enforce a rate limit on DNS
70
+	// resolution requests.
71
+	minDNSResRate = 30 * time.Second
69 72
 )
70 73
 
71 74
 var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
... ...
@@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
241 241
 			return
242 242
 		case <-d.t.C:
243 243
 		case <-d.rn:
244
+			if !d.t.Stop() {
245
+				// Before resetting a timer, it should be stopped to prevent racing with
246
+				// reads on it's channel.
247
+				<-d.t.C
248
+			}
244 249
 		}
250
+
245 251
 		result, sc := d.lookup()
246 252
 		// Next lookup should happen within an interval defined by d.freq. It may be
247 253
 		// more often due to exponential retry on empty address list.
... ...
@@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
254 254
 		}
255 255
 		d.cc.NewServiceConfig(sc)
256 256
 		d.cc.NewAddress(result)
257
+
258
+		// Sleep to prevent excessive re-resolutions. Incoming resolution requests
259
+		// will be queued in d.rn.
260
+		t := time.NewTimer(minDNSResRate)
261
+		select {
262
+		case <-t.C:
263
+		case <-d.ctx.Done():
264
+			t.Stop()
265
+			return
266
+		}
257 267
 	}
258 268
 }
259 269
 
... ...
@@ -20,6 +20,10 @@
20 20
 // All APIs in this package are experimental.
21 21
 package resolver
22 22
 
23
+import (
24
+	"google.golang.org/grpc/serviceconfig"
25
+)
26
+
23 27
 var (
24 28
 	// m is a map from scheme to resolver builder.
25 29
 	m = make(map[string]Builder)
... ...
@@ -100,11 +104,12 @@ type BuildOption struct {
100 100
 
101 101
 // State contains the current Resolver state relevant to the ClientConn.
102 102
 type State struct {
103
-	Addresses     []Address // Resolved addresses for the target
104
-	ServiceConfig string    // JSON representation of the service config
103
+	Addresses []Address // Resolved addresses for the target
104
+	// ServiceConfig is the parsed service config; obtained from
105
+	// serviceconfig.Parse.
106
+	ServiceConfig serviceconfig.Config
105 107
 
106 108
 	// TODO: add Err error
107
-	// TODO: add ParsedServiceConfig interface{}
108 109
 }
109 110
 
110 111
 // ClientConn contains the callbacks for resolver to notify any updates
... ...
@@ -132,6 +137,21 @@ type ClientConn interface {
132 132
 
133 133
 // Target represents a target for gRPC, as specified in:
134 134
 // https://github.com/grpc/grpc/blob/master/doc/naming.md.
135
+// It is parsed from the target string that gets passed into Dial or DialContext by the user. And
136
+// grpc passes it to the resolver and the balancer.
137
+//
138
+// If the target follows the naming spec, and the parsed scheme is registered with grpc, we will
139
+// parse the target string according to the spec. e.g. "dns://some_authority/foo.bar" will be parsed
140
+// into &Target{Scheme: "dns", Authority: "some_authority", Endpoint: "foo.bar"}
141
+//
142
+// If the target does not contain a scheme, we will apply the default scheme, and set the Target to
143
+// be the full target string. e.g. "foo.bar" will be parsed into
144
+// &Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "foo.bar"}.
145
+//
146
+// If the parsed scheme is not registered (i.e. no corresponding resolver available to resolve the
147
+// endpoint), we set the Scheme to be the default scheme, and set the Endpoint to be the full target
148
+// string. e.g. target string "unknown_scheme://authority/endpoint" will be parsed into
149
+// &Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "unknown_scheme://authority/endpoint"}.
135 150
 type Target struct {
136 151
 	Scheme    string
137 152
 	Authority string
... ...
@@ -138,19 +138,22 @@ func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
138 138
 		return
139 139
 	}
140 140
 	grpclog.Infof("ccResolverWrapper: got new service config: %v", sc)
141
+	c, err := parseServiceConfig(sc)
142
+	if err != nil {
143
+		return
144
+	}
141 145
 	if channelz.IsOn() {
142
-		ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: sc})
146
+		ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: c})
143 147
 	}
144
-	ccr.curState.ServiceConfig = sc
148
+	ccr.curState.ServiceConfig = c
145 149
 	ccr.cc.updateResolverState(ccr.curState)
146 150
 }
147 151
 
148 152
 func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
149
-	if s.ServiceConfig == ccr.curState.ServiceConfig && (len(ccr.curState.Addresses) == 0) == (len(s.Addresses) == 0) {
150
-		return
151
-	}
152 153
 	var updates []string
153
-	if s.ServiceConfig != ccr.curState.ServiceConfig {
154
+	oldSC, oldOK := ccr.curState.ServiceConfig.(*ServiceConfig)
155
+	newSC, newOK := s.ServiceConfig.(*ServiceConfig)
156
+	if oldOK != newOK || (oldOK && newOK && oldSC.rawJSONString != newSC.rawJSONString) {
154 157
 		updates = append(updates, "service config updated")
155 158
 	}
156 159
 	if len(ccr.curState.Addresses) > 0 && len(s.Addresses) == 0 {
... ...
@@ -694,14 +694,34 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
694 694
 	return nil
695 695
 }
696 696
 
697
+// Information about RPC
697 698
 type rpcInfo struct {
698
-	failfast bool
699
+	failfast      bool
700
+	preloaderInfo *compressorInfo
701
+}
702
+
703
+// Information about Preloader
704
+// Responsible for storing codec, and compressors
705
+// If stream (s) has  context s.Context which stores rpcInfo that has non nil
706
+// pointers to codec, and compressors, then we can use preparedMsg for Async message prep
707
+// and reuse marshalled bytes
708
+type compressorInfo struct {
709
+	codec baseCodec
710
+	cp    Compressor
711
+	comp  encoding.Compressor
699 712
 }
700 713
 
701 714
 type rpcInfoContextKey struct{}
702 715
 
703
-func newContextWithRPCInfo(ctx context.Context, failfast bool) context.Context {
704
-	return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{failfast: failfast})
716
+func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
717
+	return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
718
+		failfast: failfast,
719
+		preloaderInfo: &compressorInfo{
720
+			codec: codec,
721
+			cp:    cp,
722
+			comp:  comp,
723
+		},
724
+	})
705 725
 }
706 726
 
707 727
 func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
... ...
@@ -42,6 +42,7 @@ import (
42 42
 	"google.golang.org/grpc/grpclog"
43 43
 	"google.golang.org/grpc/internal/binarylog"
44 44
 	"google.golang.org/grpc/internal/channelz"
45
+	"google.golang.org/grpc/internal/grpcsync"
45 46
 	"google.golang.org/grpc/internal/transport"
46 47
 	"google.golang.org/grpc/keepalive"
47 48
 	"google.golang.org/grpc/metadata"
... ...
@@ -56,6 +57,8 @@ const (
56 56
 	defaultServerMaxSendMessageSize    = math.MaxInt32
57 57
 )
58 58
 
59
+var statusOK = status.New(codes.OK, "")
60
+
59 61
 type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
60 62
 
61 63
 // MethodDesc represents an RPC service's method specification.
... ...
@@ -86,21 +89,19 @@ type service struct {
86 86
 
87 87
 // Server is a gRPC server to serve RPC requests.
88 88
 type Server struct {
89
-	opts options
89
+	opts serverOptions
90 90
 
91 91
 	mu     sync.Mutex // guards following
92 92
 	lis    map[net.Listener]bool
93
-	conns  map[io.Closer]bool
93
+	conns  map[transport.ServerTransport]bool
94 94
 	serve  bool
95 95
 	drain  bool
96 96
 	cv     *sync.Cond          // signaled when connections close for GracefulStop
97 97
 	m      map[string]*service // service name -> service info
98 98
 	events trace.EventLog
99 99
 
100
-	quit               chan struct{}
101
-	done               chan struct{}
102
-	quitOnce           sync.Once
103
-	doneOnce           sync.Once
100
+	quit               *grpcsync.Event
101
+	done               *grpcsync.Event
104 102
 	channelzRemoveOnce sync.Once
105 103
 	serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop
106 104
 
... ...
@@ -108,7 +109,7 @@ type Server struct {
108 108
 	czData     *channelzData
109 109
 }
110 110
 
111
-type options struct {
111
+type serverOptions struct {
112 112
 	creds                 credentials.TransportCredentials
113 113
 	codec                 baseCodec
114 114
 	cp                    Compressor
... ...
@@ -131,7 +132,7 @@ type options struct {
131 131
 	maxHeaderListSize     *uint32
132 132
 }
133 133
 
134
-var defaultServerOptions = options{
134
+var defaultServerOptions = serverOptions{
135 135
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
136 136
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
137 137
 	connectionTimeout:     120 * time.Second,
... ...
@@ -140,7 +141,33 @@ var defaultServerOptions = options{
140 140
 }
141 141
 
142 142
 // A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
143
-type ServerOption func(*options)
143
+type ServerOption interface {
144
+	apply(*serverOptions)
145
+}
146
+
147
+// EmptyServerOption does not alter the server configuration. It can be embedded
148
+// in another structure to build custom server options.
149
+//
150
+// This API is EXPERIMENTAL.
151
+type EmptyServerOption struct{}
152
+
153
+func (EmptyServerOption) apply(*serverOptions) {}
154
+
155
+// funcServerOption wraps a function that modifies serverOptions into an
156
+// implementation of the ServerOption interface.
157
+type funcServerOption struct {
158
+	f func(*serverOptions)
159
+}
160
+
161
+func (fdo *funcServerOption) apply(do *serverOptions) {
162
+	fdo.f(do)
163
+}
164
+
165
+func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
166
+	return &funcServerOption{
167
+		f: f,
168
+	}
169
+}
144 170
 
145 171
 // WriteBufferSize determines how much data can be batched before doing a write on the wire.
146 172
 // The corresponding memory allocation for this buffer will be twice the size to keep syscalls low.
... ...
@@ -148,9 +175,9 @@ type ServerOption func(*options)
148 148
 // Zero will disable the write buffer such that each write will be on underlying connection.
149 149
 // Note: A Send call may not directly translate to a write.
150 150
 func WriteBufferSize(s int) ServerOption {
151
-	return func(o *options) {
151
+	return newFuncServerOption(func(o *serverOptions) {
152 152
 		o.writeBufferSize = s
153
-	}
153
+	})
154 154
 }
155 155
 
156 156
 // ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most
... ...
@@ -159,25 +186,25 @@ func WriteBufferSize(s int) ServerOption {
159 159
 // Zero will disable read buffer for a connection so data framer can access the underlying
160 160
 // conn directly.
161 161
 func ReadBufferSize(s int) ServerOption {
162
-	return func(o *options) {
162
+	return newFuncServerOption(func(o *serverOptions) {
163 163
 		o.readBufferSize = s
164
-	}
164
+	})
165 165
 }
166 166
 
167 167
 // InitialWindowSize returns a ServerOption that sets window size for stream.
168 168
 // The lower bound for window size is 64K and any value smaller than that will be ignored.
169 169
 func InitialWindowSize(s int32) ServerOption {
170
-	return func(o *options) {
170
+	return newFuncServerOption(func(o *serverOptions) {
171 171
 		o.initialWindowSize = s
172
-	}
172
+	})
173 173
 }
174 174
 
175 175
 // InitialConnWindowSize returns a ServerOption that sets window size for a connection.
176 176
 // The lower bound for window size is 64K and any value smaller than that will be ignored.
177 177
 func InitialConnWindowSize(s int32) ServerOption {
178
-	return func(o *options) {
178
+	return newFuncServerOption(func(o *serverOptions) {
179 179
 		o.initialConnWindowSize = s
180
-	}
180
+	})
181 181
 }
182 182
 
183 183
 // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
... ...
@@ -187,25 +214,25 @@ func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
187 187
 		kp.Time = time.Second
188 188
 	}
189 189
 
190
-	return func(o *options) {
190
+	return newFuncServerOption(func(o *serverOptions) {
191 191
 		o.keepaliveParams = kp
192
-	}
192
+	})
193 193
 }
194 194
 
195 195
 // KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
196 196
 func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
197
-	return func(o *options) {
197
+	return newFuncServerOption(func(o *serverOptions) {
198 198
 		o.keepalivePolicy = kep
199
-	}
199
+	})
200 200
 }
201 201
 
202 202
 // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
203 203
 //
204 204
 // This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
205 205
 func CustomCodec(codec Codec) ServerOption {
206
-	return func(o *options) {
206
+	return newFuncServerOption(func(o *serverOptions) {
207 207
 		o.codec = codec
208
-	}
208
+	})
209 209
 }
210 210
 
211 211
 // RPCCompressor returns a ServerOption that sets a compressor for outbound
... ...
@@ -216,9 +243,9 @@ func CustomCodec(codec Codec) ServerOption {
216 216
 //
217 217
 // Deprecated: use encoding.RegisterCompressor instead.
218 218
 func RPCCompressor(cp Compressor) ServerOption {
219
-	return func(o *options) {
219
+	return newFuncServerOption(func(o *serverOptions) {
220 220
 		o.cp = cp
221
-	}
221
+	})
222 222
 }
223 223
 
224 224
 // RPCDecompressor returns a ServerOption that sets a decompressor for inbound
... ...
@@ -227,9 +254,9 @@ func RPCCompressor(cp Compressor) ServerOption {
227 227
 //
228 228
 // Deprecated: use encoding.RegisterCompressor instead.
229 229
 func RPCDecompressor(dc Decompressor) ServerOption {
230
-	return func(o *options) {
230
+	return newFuncServerOption(func(o *serverOptions) {
231 231
 		o.dc = dc
232
-	}
232
+	})
233 233
 }
234 234
 
235 235
 // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
... ...
@@ -243,73 +270,73 @@ func MaxMsgSize(m int) ServerOption {
243 243
 // MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
244 244
 // If this is not set, gRPC uses the default 4MB.
245 245
 func MaxRecvMsgSize(m int) ServerOption {
246
-	return func(o *options) {
246
+	return newFuncServerOption(func(o *serverOptions) {
247 247
 		o.maxReceiveMessageSize = m
248
-	}
248
+	})
249 249
 }
250 250
 
251 251
 // MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
252 252
 // If this is not set, gRPC uses the default `math.MaxInt32`.
253 253
 func MaxSendMsgSize(m int) ServerOption {
254
-	return func(o *options) {
254
+	return newFuncServerOption(func(o *serverOptions) {
255 255
 		o.maxSendMessageSize = m
256
-	}
256
+	})
257 257
 }
258 258
 
259 259
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
260 260
 // of concurrent streams to each ServerTransport.
261 261
 func MaxConcurrentStreams(n uint32) ServerOption {
262
-	return func(o *options) {
262
+	return newFuncServerOption(func(o *serverOptions) {
263 263
 		o.maxConcurrentStreams = n
264
-	}
264
+	})
265 265
 }
266 266
 
267 267
 // Creds returns a ServerOption that sets credentials for server connections.
268 268
 func Creds(c credentials.TransportCredentials) ServerOption {
269
-	return func(o *options) {
269
+	return newFuncServerOption(func(o *serverOptions) {
270 270
 		o.creds = c
271
-	}
271
+	})
272 272
 }
273 273
 
274 274
 // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
275 275
 // server. Only one unary interceptor can be installed. The construction of multiple
276 276
 // interceptors (e.g., chaining) can be implemented at the caller.
277 277
 func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
278
-	return func(o *options) {
278
+	return newFuncServerOption(func(o *serverOptions) {
279 279
 		if o.unaryInt != nil {
280 280
 			panic("The unary server interceptor was already set and may not be reset.")
281 281
 		}
282 282
 		o.unaryInt = i
283
-	}
283
+	})
284 284
 }
285 285
 
286 286
 // StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
287 287
 // server. Only one stream interceptor can be installed.
288 288
 func StreamInterceptor(i StreamServerInterceptor) ServerOption {
289
-	return func(o *options) {
289
+	return newFuncServerOption(func(o *serverOptions) {
290 290
 		if o.streamInt != nil {
291 291
 			panic("The stream server interceptor was already set and may not be reset.")
292 292
 		}
293 293
 		o.streamInt = i
294
-	}
294
+	})
295 295
 }
296 296
 
297 297
 // InTapHandle returns a ServerOption that sets the tap handle for all the server
298 298
 // transport to be created. Only one can be installed.
299 299
 func InTapHandle(h tap.ServerInHandle) ServerOption {
300
-	return func(o *options) {
300
+	return newFuncServerOption(func(o *serverOptions) {
301 301
 		if o.inTapHandle != nil {
302 302
 			panic("The tap handle was already set and may not be reset.")
303 303
 		}
304 304
 		o.inTapHandle = h
305
-	}
305
+	})
306 306
 }
307 307
 
308 308
 // StatsHandler returns a ServerOption that sets the stats handler for the server.
309 309
 func StatsHandler(h stats.Handler) ServerOption {
310
-	return func(o *options) {
310
+	return newFuncServerOption(func(o *serverOptions) {
311 311
 		o.statsHandler = h
312
-	}
312
+	})
313 313
 }
314 314
 
315 315
 // UnknownServiceHandler returns a ServerOption that allows for adding a custom
... ...
@@ -319,7 +346,7 @@ func StatsHandler(h stats.Handler) ServerOption {
319 319
 // The handling function has full access to the Context of the request and the
320 320
 // stream, and the invocation bypasses interceptors.
321 321
 func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
322
-	return func(o *options) {
322
+	return newFuncServerOption(func(o *serverOptions) {
323 323
 		o.unknownStreamDesc = &StreamDesc{
324 324
 			StreamName: "unknown_service_handler",
325 325
 			Handler:    streamHandler,
... ...
@@ -327,7 +354,7 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
327 327
 			ClientStreams: true,
328 328
 			ServerStreams: true,
329 329
 		}
330
-	}
330
+	})
331 331
 }
332 332
 
333 333
 // ConnectionTimeout returns a ServerOption that sets the timeout for
... ...
@@ -337,17 +364,17 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
337 337
 //
338 338
 // This API is EXPERIMENTAL.
339 339
 func ConnectionTimeout(d time.Duration) ServerOption {
340
-	return func(o *options) {
340
+	return newFuncServerOption(func(o *serverOptions) {
341 341
 		o.connectionTimeout = d
342
-	}
342
+	})
343 343
 }
344 344
 
345 345
 // MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
346 346
 // of header list that the server is prepared to accept.
347 347
 func MaxHeaderListSize(s uint32) ServerOption {
348
-	return func(o *options) {
348
+	return newFuncServerOption(func(o *serverOptions) {
349 349
 		o.maxHeaderListSize = &s
350
-	}
350
+	})
351 351
 }
352 352
 
353 353
 // NewServer creates a gRPC server which has no service registered and has not
... ...
@@ -355,15 +382,15 @@ func MaxHeaderListSize(s uint32) ServerOption {
355 355
 func NewServer(opt ...ServerOption) *Server {
356 356
 	opts := defaultServerOptions
357 357
 	for _, o := range opt {
358
-		o(&opts)
358
+		o.apply(&opts)
359 359
 	}
360 360
 	s := &Server{
361 361
 		lis:    make(map[net.Listener]bool),
362 362
 		opts:   opts,
363
-		conns:  make(map[io.Closer]bool),
363
+		conns:  make(map[transport.ServerTransport]bool),
364 364
 		m:      make(map[string]*service),
365
-		quit:   make(chan struct{}),
366
-		done:   make(chan struct{}),
365
+		quit:   grpcsync.NewEvent(),
366
+		done:   grpcsync.NewEvent(),
367 367
 		czData: new(channelzData),
368 368
 	}
369 369
 	s.cv = sync.NewCond(&s.mu)
... ...
@@ -530,11 +557,9 @@ func (s *Server) Serve(lis net.Listener) error {
530 530
 	s.serveWG.Add(1)
531 531
 	defer func() {
532 532
 		s.serveWG.Done()
533
-		select {
534
-		// Stop or GracefulStop called; block until done and return nil.
535
-		case <-s.quit:
536
-			<-s.done
537
-		default:
533
+		if s.quit.HasFired() {
534
+			// Stop or GracefulStop called; block until done and return nil.
535
+			<-s.done.Done()
538 536
 		}
539 537
 	}()
540 538
 
... ...
@@ -577,7 +602,7 @@ func (s *Server) Serve(lis net.Listener) error {
577 577
 				timer := time.NewTimer(tempDelay)
578 578
 				select {
579 579
 				case <-timer.C:
580
-				case <-s.quit:
580
+				case <-s.quit.Done():
581 581
 					timer.Stop()
582 582
 					return nil
583 583
 				}
... ...
@@ -587,10 +612,8 @@ func (s *Server) Serve(lis net.Listener) error {
587 587
 			s.printf("done serving; Accept = %v", err)
588 588
 			s.mu.Unlock()
589 589
 
590
-			select {
591
-			case <-s.quit:
590
+			if s.quit.HasFired() {
592 591
 				return nil
593
-			default:
594 592
 			}
595 593
 			return err
596 594
 		}
... ...
@@ -611,6 +634,10 @@ func (s *Server) Serve(lis net.Listener) error {
611 611
 // handleRawConn forks a goroutine to handle a just-accepted connection that
612 612
 // has not had any I/O performed on it yet.
613 613
 func (s *Server) handleRawConn(rawConn net.Conn) {
614
+	if s.quit.HasFired() {
615
+		rawConn.Close()
616
+		return
617
+	}
614 618
 	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
615 619
 	conn, authInfo, err := s.useTransportAuthenticator(rawConn)
616 620
 	if err != nil {
... ...
@@ -627,14 +654,6 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
627 627
 		return
628 628
 	}
629 629
 
630
-	s.mu.Lock()
631
-	if s.conns == nil {
632
-		s.mu.Unlock()
633
-		conn.Close()
634
-		return
635
-	}
636
-	s.mu.Unlock()
637
-
638 630
 	// Finish handshaking (HTTP2)
639 631
 	st := s.newHTTP2Transport(conn, authInfo)
640 632
 	if st == nil {
... ...
@@ -742,6 +761,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
742 742
 // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
743 743
 // If tracing is not enabled, it returns nil.
744 744
 func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
745
+	if !EnableTracing {
746
+		return nil
747
+	}
745 748
 	tr, ok := trace.FromContext(stream.Context())
746 749
 	if !ok {
747 750
 		return nil
... ...
@@ -760,27 +782,27 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
760 760
 	return trInfo
761 761
 }
762 762
 
763
-func (s *Server) addConn(c io.Closer) bool {
763
+func (s *Server) addConn(st transport.ServerTransport) bool {
764 764
 	s.mu.Lock()
765 765
 	defer s.mu.Unlock()
766 766
 	if s.conns == nil {
767
-		c.Close()
767
+		st.Close()
768 768
 		return false
769 769
 	}
770 770
 	if s.drain {
771 771
 		// Transport added after we drained our existing conns: drain it
772 772
 		// immediately.
773
-		c.(transport.ServerTransport).Drain()
773
+		st.Drain()
774 774
 	}
775
-	s.conns[c] = true
775
+	s.conns[st] = true
776 776
 	return true
777 777
 }
778 778
 
779
-func (s *Server) removeConn(c io.Closer) {
779
+func (s *Server) removeConn(st transport.ServerTransport) {
780 780
 	s.mu.Lock()
781 781
 	defer s.mu.Unlock()
782 782
 	if s.conns != nil {
783
-		delete(s.conns, c)
783
+		delete(s.conns, st)
784 784
 		s.cv.Broadcast()
785 785
 	}
786 786
 }
... ...
@@ -952,10 +974,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
952 952
 		}
953 953
 		if sh != nil {
954 954
 			sh.HandleRPC(stream.Context(), &stats.InPayload{
955
-				RecvTime: time.Now(),
956
-				Payload:  v,
957
-				Data:     d,
958
-				Length:   len(d),
955
+				RecvTime:   time.Now(),
956
+				Payload:    v,
957
+				WireLength: payInfo.wireLength,
958
+				Data:       d,
959
+				Length:     len(d),
959 960
 			})
960 961
 		}
961 962
 		if binlog != nil {
... ...
@@ -1051,7 +1074,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
1051 1051
 	// TODO: Should we be logging if writing status failed here, like above?
1052 1052
 	// Should the logging be in WriteStatus?  Should we ignore the WriteStatus
1053 1053
 	// error or allow the stats handler to see it?
1054
-	err = t.WriteStatus(stream, status.New(codes.OK, ""))
1054
+	err = t.WriteStatus(stream, statusOK)
1055 1055
 	if binlog != nil {
1056 1056
 		binlog.Log(&binarylog.ServerTrailer{
1057 1057
 			Trailer: stream.Trailer(),
... ...
@@ -1209,7 +1232,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
1209 1209
 		ss.trInfo.tr.LazyLog(stringer("OK"), false)
1210 1210
 		ss.mu.Unlock()
1211 1211
 	}
1212
-	err = t.WriteStatus(ss.s, status.New(codes.OK, ""))
1212
+	err = t.WriteStatus(ss.s, statusOK)
1213 1213
 	if ss.binlog != nil {
1214 1214
 		ss.binlog.Log(&binarylog.ServerTrailer{
1215 1215
 			Trailer: ss.s.Trailer(),
... ...
@@ -1326,15 +1349,11 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream
1326 1326
 // pending RPCs on the client side will get notified by connection
1327 1327
 // errors.
1328 1328
 func (s *Server) Stop() {
1329
-	s.quitOnce.Do(func() {
1330
-		close(s.quit)
1331
-	})
1329
+	s.quit.Fire()
1332 1330
 
1333 1331
 	defer func() {
1334 1332
 		s.serveWG.Wait()
1335
-		s.doneOnce.Do(func() {
1336
-			close(s.done)
1337
-		})
1333
+		s.done.Fire()
1338 1334
 	}()
1339 1335
 
1340 1336
 	s.channelzRemoveOnce.Do(func() {
... ...
@@ -1371,15 +1390,8 @@ func (s *Server) Stop() {
1371 1371
 // accepting new connections and RPCs and blocks until all the pending RPCs are
1372 1372
 // finished.
1373 1373
 func (s *Server) GracefulStop() {
1374
-	s.quitOnce.Do(func() {
1375
-		close(s.quit)
1376
-	})
1377
-
1378
-	defer func() {
1379
-		s.doneOnce.Do(func() {
1380
-			close(s.done)
1381
-		})
1382
-	}()
1374
+	s.quit.Fire()
1375
+	defer s.done.Fire()
1383 1376
 
1384 1377
 	s.channelzRemoveOnce.Do(func() {
1385 1378
 		if channelz.IsOn() {
... ...
@@ -1397,8 +1409,8 @@ func (s *Server) GracefulStop() {
1397 1397
 	}
1398 1398
 	s.lis = nil
1399 1399
 	if !s.drain {
1400
-		for c := range s.conns {
1401
-			c.(transport.ServerTransport).Drain()
1400
+		for st := range s.conns {
1401
+			st.Drain()
1402 1402
 		}
1403 1403
 		s.drain = true
1404 1404
 	}
... ...
@@ -25,8 +25,11 @@ import (
25 25
 	"strings"
26 26
 	"time"
27 27
 
28
+	"google.golang.org/grpc/balancer"
28 29
 	"google.golang.org/grpc/codes"
29 30
 	"google.golang.org/grpc/grpclog"
31
+	"google.golang.org/grpc/internal"
32
+	"google.golang.org/grpc/serviceconfig"
30 33
 )
31 34
 
32 35
 const maxInt = int(^uint(0) >> 1)
... ...
@@ -61,6 +64,11 @@ type MethodConfig struct {
61 61
 	retryPolicy *retryPolicy
62 62
 }
63 63
 
64
+type lbConfig struct {
65
+	name string
66
+	cfg  serviceconfig.LoadBalancingConfig
67
+}
68
+
64 69
 // ServiceConfig is provided by the service provider and contains parameters for how
65 70
 // clients that connect to the service should behave.
66 71
 //
... ...
@@ -68,10 +76,18 @@ type MethodConfig struct {
68 68
 // through name resolver, as specified here
69 69
 // https://github.com/grpc/grpc/blob/master/doc/service_config.md
70 70
 type ServiceConfig struct {
71
-	// LB is the load balancer the service providers recommends. The balancer specified
72
-	// via grpc.WithBalancer will override this.
71
+	serviceconfig.Config
72
+
73
+	// LB is the load balancer the service providers recommends. The balancer
74
+	// specified via grpc.WithBalancer will override this.  This is deprecated;
75
+	// lbConfigs is preferred.  If lbConfig and LB are both present, lbConfig
76
+	// will be used.
73 77
 	LB *string
74 78
 
79
+	// lbConfig is the service config's load balancing configuration.  If
80
+	// lbConfig and LB are both present, lbConfig will be used.
81
+	lbConfig *lbConfig
82
+
75 83
 	// Methods contains a map for the methods in this service.  If there is an
76 84
 	// exact match for a method (i.e. /service/method) in the map, use the
77 85
 	// corresponding MethodConfig.  If there's no exact match, look for the
... ...
@@ -233,15 +249,27 @@ type jsonMC struct {
233 233
 	RetryPolicy             *jsonRetryPolicy
234 234
 }
235 235
 
236
+type loadBalancingConfig map[string]json.RawMessage
237
+
236 238
 // TODO(lyuxuan): delete this struct after cleaning up old service config implementation.
237 239
 type jsonSC struct {
238 240
 	LoadBalancingPolicy *string
241
+	LoadBalancingConfig *[]loadBalancingConfig
239 242
 	MethodConfig        *[]jsonMC
240 243
 	RetryThrottling     *retryThrottlingPolicy
241 244
 	HealthCheckConfig   *healthCheckConfig
242 245
 }
243 246
 
247
+func init() {
248
+	internal.ParseServiceConfig = func(sc string) (interface{}, error) {
249
+		return parseServiceConfig(sc)
250
+	}
251
+}
252
+
244 253
 func parseServiceConfig(js string) (*ServiceConfig, error) {
254
+	if len(js) == 0 {
255
+		return nil, fmt.Errorf("no JSON service config provided")
256
+	}
245 257
 	var rsc jsonSC
246 258
 	err := json.Unmarshal([]byte(js), &rsc)
247 259
 	if err != nil {
... ...
@@ -255,10 +283,38 @@ func parseServiceConfig(js string) (*ServiceConfig, error) {
255 255
 		healthCheckConfig: rsc.HealthCheckConfig,
256 256
 		rawJSONString:     js,
257 257
 	}
258
+	if rsc.LoadBalancingConfig != nil {
259
+		for i, lbcfg := range *rsc.LoadBalancingConfig {
260
+			if len(lbcfg) != 1 {
261
+				err := fmt.Errorf("invalid loadBalancingConfig: entry %v does not contain exactly 1 policy/config pair: %q", i, lbcfg)
262
+				grpclog.Warningf(err.Error())
263
+				return nil, err
264
+			}
265
+			var name string
266
+			var jsonCfg json.RawMessage
267
+			for name, jsonCfg = range lbcfg {
268
+			}
269
+			builder := balancer.Get(name)
270
+			if builder == nil {
271
+				continue
272
+			}
273
+			sc.lbConfig = &lbConfig{name: name}
274
+			if parser, ok := builder.(balancer.ConfigParser); ok {
275
+				var err error
276
+				sc.lbConfig.cfg, err = parser.ParseConfig(jsonCfg)
277
+				if err != nil {
278
+					return nil, fmt.Errorf("error parsing loadBalancingConfig for policy %q: %v", name, err)
279
+				}
280
+			} else if string(jsonCfg) != "{}" {
281
+				grpclog.Warningf("non-empty balancer configuration %q, but balancer does not implement ParseConfig", string(jsonCfg))
282
+			}
283
+			break
284
+		}
285
+	}
286
+
258 287
 	if rsc.MethodConfig == nil {
259 288
 		return &sc, nil
260 289
 	}
261
-
262 290
 	for _, m := range *rsc.MethodConfig {
263 291
 		if m.Name == nil {
264 292
 			continue
... ...
@@ -299,11 +355,11 @@ func parseServiceConfig(js string) (*ServiceConfig, error) {
299 299
 	}
300 300
 
301 301
 	if sc.retryThrottling != nil {
302
-		if sc.retryThrottling.MaxTokens <= 0 ||
303
-			sc.retryThrottling.MaxTokens > 1000 ||
304
-			sc.retryThrottling.TokenRatio <= 0 {
305
-			// Illegal throttling config; disable throttling.
306
-			sc.retryThrottling = nil
302
+		if mt := sc.retryThrottling.MaxTokens; mt <= 0 || mt > 1000 {
303
+			return nil, fmt.Errorf("invalid retry throttling config: maxTokens (%v) out of range (0, 1000]", mt)
304
+		}
305
+		if tr := sc.retryThrottling.TokenRatio; tr <= 0 {
306
+			return nil, fmt.Errorf("invalid retry throttling config: tokenRatio (%v) may not be negative", tr)
307 307
 		}
308 308
 	}
309 309
 	return &sc, nil
310 310
new file mode 100644
... ...
@@ -0,0 +1,48 @@
0
+/*
1
+ *
2
+ * Copyright 2019 gRPC authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ *     http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ *
16
+ */
17
+
18
+// Package serviceconfig defines types and methods for operating on gRPC
19
+// service configs.
20
+//
21
+// This package is EXPERIMENTAL.
22
+package serviceconfig
23
+
24
+import (
25
+	"google.golang.org/grpc/internal"
26
+)
27
+
28
+// Config represents an opaque data structure holding a service config.
29
+type Config interface {
30
+	isConfig()
31
+}
32
+
33
+// LoadBalancingConfig represents an opaque data structure holding a load
34
+// balancer config.
35
+type LoadBalancingConfig interface {
36
+	isLoadBalancingConfig()
37
+}
38
+
39
+// Parse parses the JSON service config provided into an internal form or
40
+// returns an error if the config is invalid.
41
+func Parse(ServiceConfigJSON string) (Config, error) {
42
+	c, err := internal.ParseServiceConfig(ServiceConfigJSON)
43
+	if err != nil {
44
+		return nil, err
45
+	}
46
+	return c.(Config), err
47
+}
... ...
@@ -36,8 +36,15 @@ import (
36 36
 	"github.com/golang/protobuf/ptypes"
37 37
 	spb "google.golang.org/genproto/googleapis/rpc/status"
38 38
 	"google.golang.org/grpc/codes"
39
+	"google.golang.org/grpc/internal"
39 40
 )
40 41
 
42
+func init() {
43
+	internal.StatusRawProto = statusRawProto
44
+}
45
+
46
+func statusRawProto(s *Status) *spb.Status { return s.s }
47
+
41 48
 // statusError is an alias of a status proto.  It implements error and Status,
42 49
 // and a nil statusError should never be returned by this package.
43 50
 type statusError spb.Status
... ...
@@ -51,6 +58,17 @@ func (se *statusError) GRPCStatus() *Status {
51 51
 	return &Status{s: (*spb.Status)(se)}
52 52
 }
53 53
 
54
+// Is implements future error.Is functionality.
55
+// A statusError is equivalent if the code and message are identical.
56
+func (se *statusError) Is(target error) bool {
57
+	tse, ok := target.(*statusError)
58
+	if !ok {
59
+		return false
60
+	}
61
+
62
+	return proto.Equal((*spb.Status)(se), (*spb.Status)(tse))
63
+}
64
+
54 65
 // Status represents an RPC status code, message, and details.  It is immutable
55 66
 // and should be created with New, Newf, or FromProto.
56 67
 type Status struct {
... ...
@@ -125,7 +143,7 @@ func FromProto(s *spb.Status) *Status {
125 125
 // Status is returned with codes.Unknown and the original error message.
126 126
 func FromError(err error) (s *Status, ok bool) {
127 127
 	if err == nil {
128
-		return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true
128
+		return nil, true
129 129
 	}
130 130
 	if se, ok := err.(interface {
131 131
 		GRPCStatus() *Status
... ...
@@ -199,7 +217,7 @@ func Code(err error) codes.Code {
199 199
 func FromContextError(err error) *Status {
200 200
 	switch err {
201 201
 	case nil:
202
-		return New(codes.OK, "")
202
+		return nil
203 203
 	case context.DeadlineExceeded:
204 204
 		return New(codes.DeadlineExceeded, err.Error())
205 205
 	case context.Canceled:
... ...
@@ -30,7 +30,6 @@ import (
30 30
 	"golang.org/x/net/trace"
31 31
 	"google.golang.org/grpc/balancer"
32 32
 	"google.golang.org/grpc/codes"
33
-	"google.golang.org/grpc/connectivity"
34 33
 	"google.golang.org/grpc/encoding"
35 34
 	"google.golang.org/grpc/grpclog"
36 35
 	"google.golang.org/grpc/internal/balancerload"
... ...
@@ -245,7 +244,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
245 245
 		trInfo.tr.LazyLog(&trInfo.firstLine, false)
246 246
 		ctx = trace.NewContext(ctx, trInfo.tr)
247 247
 	}
248
-	ctx = newContextWithRPCInfo(ctx, c.failFast)
248
+	ctx = newContextWithRPCInfo(ctx, c.failFast, c.codec, cp, comp)
249 249
 	sh := cc.dopts.copts.StatsHandler
250 250
 	var beginTime time.Time
251 251
 	if sh != nil {
... ...
@@ -328,13 +327,23 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
328 328
 	return cs, nil
329 329
 }
330 330
 
331
-func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) error {
332
-	cs.attempt = &csAttempt{
331
+// newAttemptLocked creates a new attempt with a transport.
332
+// If it succeeds, then it replaces clientStream's attempt with this new attempt.
333
+func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (retErr error) {
334
+	newAttempt := &csAttempt{
333 335
 		cs:           cs,
334 336
 		dc:           cs.cc.dopts.dc,
335 337
 		statsHandler: sh,
336 338
 		trInfo:       trInfo,
337 339
 	}
340
+	defer func() {
341
+		if retErr != nil {
342
+			// This attempt is not set in the clientStream, so it's finish won't
343
+			// be called. Call it here for stats and trace in case they are not
344
+			// nil.
345
+			newAttempt.finish(retErr)
346
+		}
347
+	}()
338 348
 
339 349
 	if err := cs.ctx.Err(); err != nil {
340 350
 		return toRPCErr(err)
... ...
@@ -346,8 +355,9 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) er
346 346
 	if trInfo != nil {
347 347
 		trInfo.firstLine.SetRemoteAddr(t.RemoteAddr())
348 348
 	}
349
-	cs.attempt.t = t
350
-	cs.attempt.done = done
349
+	newAttempt.t = t
350
+	newAttempt.done = done
351
+	cs.attempt = newAttempt
351 352
 	return nil
352 353
 }
353 354
 
... ...
@@ -396,11 +406,18 @@ type clientStream struct {
396 396
 	serverHeaderBinlogged bool
397 397
 
398 398
 	mu                      sync.Mutex
399
-	firstAttempt            bool       // if true, transparent retry is valid
400
-	numRetries              int        // exclusive of transparent retry attempt(s)
401
-	numRetriesSincePushback int        // retries since pushback; to reset backoff
402
-	finished                bool       // TODO: replace with atomic cmpxchg or sync.Once?
403
-	attempt                 *csAttempt // the active client stream attempt
399
+	firstAttempt            bool // if true, transparent retry is valid
400
+	numRetries              int  // exclusive of transparent retry attempt(s)
401
+	numRetriesSincePushback int  // retries since pushback; to reset backoff
402
+	finished                bool // TODO: replace with atomic cmpxchg or sync.Once?
403
+	// attempt is the active client stream attempt.
404
+	// The only place where it is written is the newAttemptLocked method and this method never writes nil.
405
+	// So, attempt can be nil only inside newClientStream function when clientStream is first created.
406
+	// One of the first things done after clientStream's creation, is to call newAttemptLocked which either
407
+	// assigns a non nil value to the attempt or returns an error. If an error is returned from newAttemptLocked,
408
+	// then newClientStream calls finish on the clientStream and returns. So, finish method is the only
409
+	// place where we need to check if the attempt is nil.
410
+	attempt *csAttempt
404 411
 	// TODO(hedging): hedging will have multiple attempts simultaneously.
405 412
 	committed  bool                       // active attempt committed for retry?
406 413
 	buffer     []func(a *csAttempt) error // operations to replay on retry
... ...
@@ -458,8 +475,8 @@ func (cs *clientStream) shouldRetry(err error) error {
458 458
 	if cs.attempt.s != nil {
459 459
 		<-cs.attempt.s.Done()
460 460
 	}
461
-	if cs.firstAttempt && !cs.callInfo.failFast && (cs.attempt.s == nil || cs.attempt.s.Unprocessed()) {
462
-		// First attempt, wait-for-ready, stream unprocessed: transparently retry.
461
+	if cs.firstAttempt && (cs.attempt.s == nil || cs.attempt.s.Unprocessed()) {
462
+		// First attempt, stream unprocessed: transparently retry.
463 463
 		cs.firstAttempt = false
464 464
 		return nil
465 465
 	}
... ...
@@ -677,15 +694,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
677 677
 	if !cs.desc.ClientStreams {
678 678
 		cs.sentLast = true
679 679
 	}
680
-	data, err := encode(cs.codec, m)
681
-	if err != nil {
682
-		return err
683
-	}
684
-	compData, err := compress(data, cs.cp, cs.comp)
680
+
681
+	// load hdr, payload, data
682
+	hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
685 683
 	if err != nil {
686 684
 		return err
687 685
 	}
688
-	hdr, payload := msgHeader(data, compData)
686
+
689 687
 	// TODO(dfawley): should we be checking len(data) instead?
690 688
 	if len(payload) > *cs.callInfo.maxSendMessageSize {
691 689
 		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
... ...
@@ -808,11 +823,11 @@ func (cs *clientStream) finish(err error) {
808 808
 	}
809 809
 	if cs.attempt != nil {
810 810
 		cs.attempt.finish(err)
811
-	}
812
-	// after functions all rely upon having a stream.
813
-	if cs.attempt.s != nil {
814
-		for _, o := range cs.opts {
815
-			o.after(cs.callInfo)
811
+		// after functions all rely upon having a stream.
812
+		if cs.attempt.s != nil {
813
+			for _, o := range cs.opts {
814
+				o.after(cs.callInfo)
815
+			}
816 816
 		}
817 817
 	}
818 818
 	cs.cancel()
... ...
@@ -967,19 +982,18 @@ func (a *csAttempt) finish(err error) {
967 967
 	a.mu.Unlock()
968 968
 }
969 969
 
970
-func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, opts ...CallOption) (_ ClientStream, err error) {
971
-	ac.mu.Lock()
972
-	if ac.transport != t {
973
-		ac.mu.Unlock()
974
-		return nil, status.Error(codes.Canceled, "the provided transport is no longer valid to use")
975
-	}
976
-	// transition to CONNECTING state when an attempt starts
977
-	if ac.state != connectivity.Connecting {
978
-		ac.updateConnectivityState(connectivity.Connecting)
979
-		ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
980
-	}
981
-	ac.mu.Unlock()
982
-
970
+// newClientStream creates a ClientStream with the specified transport, on the
971
+// given addrConn.
972
+//
973
+// It's expected that the given transport is either the same one in addrConn, or
974
+// is already closed. To avoid race, transport is specified separately, instead
975
+// of using ac.transpot.
976
+//
977
+// Main difference between this and ClientConn.NewStream:
978
+// - no retry
979
+// - no service config (or wait for service config)
980
+// - no tracing or stats
981
+func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, ac *addrConn, opts ...CallOption) (_ ClientStream, err error) {
983 982
 	if t == nil {
984 983
 		// TODO: return RPC error here?
985 984
 		return nil, errors.New("transport provided is nil")
... ...
@@ -987,14 +1001,6 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
987 987
 	// defaultCallInfo contains unnecessary info(i.e. failfast, maxRetryRPCBufferSize), so we just initialize an empty struct.
988 988
 	c := &callInfo{}
989 989
 
990
-	for _, o := range opts {
991
-		if err := o.before(c); err != nil {
992
-			return nil, toRPCErr(err)
993
-		}
994
-	}
995
-	c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
996
-	c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize)
997
-
998 990
 	// Possible context leak:
999 991
 	// The cancel function for the child context we create will only be called
1000 992
 	// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
... ...
@@ -1007,6 +1013,13 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
1007 1007
 		}
1008 1008
 	}()
1009 1009
 
1010
+	for _, o := range opts {
1011
+		if err := o.before(c); err != nil {
1012
+			return nil, toRPCErr(err)
1013
+		}
1014
+	}
1015
+	c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
1016
+	c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize)
1010 1017
 	if err := setCallInfoCodec(c); err != nil {
1011 1018
 		return nil, err
1012 1019
 	}
... ...
@@ -1039,6 +1052,7 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
1039 1039
 		callHdr.Creds = c.creds
1040 1040
 	}
1041 1041
 
1042
+	// Use a special addrConnStream to avoid retry.
1042 1043
 	as := &addrConnStream{
1043 1044
 		callHdr:  callHdr,
1044 1045
 		ac:       ac,
... ...
@@ -1150,15 +1164,13 @@ func (as *addrConnStream) SendMsg(m interface{}) (err error) {
1150 1150
 	if !as.desc.ClientStreams {
1151 1151
 		as.sentLast = true
1152 1152
 	}
1153
-	data, err := encode(as.codec, m)
1154
-	if err != nil {
1155
-		return err
1156
-	}
1157
-	compData, err := compress(data, as.cp, as.comp)
1153
+
1154
+	// load hdr, payload, data
1155
+	hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp)
1158 1156
 	if err != nil {
1159 1157
 		return err
1160 1158
 	}
1161
-	hdr, payld := msgHeader(data, compData)
1159
+
1162 1160
 	// TODO(dfawley): should we be checking len(data) instead?
1163 1161
 	if len(payld) > *as.callInfo.maxSendMessageSize {
1164 1162
 		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize)
... ...
@@ -1395,15 +1407,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
1395 1395
 			ss.t.IncrMsgSent()
1396 1396
 		}
1397 1397
 	}()
1398
-	data, err := encode(ss.codec, m)
1399
-	if err != nil {
1400
-		return err
1401
-	}
1402
-	compData, err := compress(data, ss.cp, ss.comp)
1398
+
1399
+	// load hdr, payload, data
1400
+	hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
1403 1401
 	if err != nil {
1404 1402
 		return err
1405 1403
 	}
1406
-	hdr, payload := msgHeader(data, compData)
1404
+
1407 1405
 	// TODO(dfawley): should we be checking len(data) instead?
1408 1406
 	if len(payload) > ss.maxSendMessageSize {
1409 1407
 		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
... ...
@@ -1496,3 +1506,24 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
1496 1496
 func MethodFromServerStream(stream ServerStream) (string, bool) {
1497 1497
 	return Method(stream.Context())
1498 1498
 }
1499
+
1500
+// prepareMsg returns the hdr, payload and data
1501
+// using the compressors passed or using the
1502
+// passed preparedmsg
1503
+func prepareMsg(m interface{}, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) {
1504
+	if preparedMsg, ok := m.(*PreparedMsg); ok {
1505
+		return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil
1506
+	}
1507
+	// The input interface is not a prepared msg.
1508
+	// Marshal and Compress the data at this point
1509
+	data, err = encode(codec, m)
1510
+	if err != nil {
1511
+		return nil, nil, nil, err
1512
+	}
1513
+	compData, err := compress(data, cp, comp)
1514
+	if err != nil {
1515
+		return nil, nil, nil, err
1516
+	}
1517
+	hdr, payload = msgHeader(data, compData)
1518
+	return hdr, payload, data, nil
1519
+}
... ...
@@ -19,4 +19,4 @@
19 19
 package grpc
20 20
 
21 21
 // Version is the current grpc version.
22
-const Version = "1.20.1"
22
+const Version = "1.23.0"