full diff: https://github.com/containerd/ttrpc/compare/699c4e40d1e7416e08bf7019c7ce2e9beced4636...92c8520ef9f86600c650dd540266a007bf03670f
changes:
- containerd/ttrpc#37 Handle EOF to prevent file descriptor leak
- containerd/ttrpc#38 Improve connection error handling
- containerd/ttrpc#40 Support headers
- containerd/ttrpc#41 Add client and server unary interceptors
- containerd/ttrpc#43 metadata as KeyValue type
- containerd/ttrpc#42 Refactor close handling for ttrpc clients
- containerd/ttrpc#44 Fix method full name generation
- containerd/ttrpc#46 Client.Call(): do not return error if no Status is set (gRPC v1.23 and up)
- containerd/ttrpc#49 Handle ok status
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
(cherry picked from commit 8769255d1bb9c469d4f2966e7e9869a9f126f9e9)
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
| ... | ... |
@@ -126,7 +126,7 @@ github.com/containerd/cgroups 4994991857f9b0ae8dc439551e8b |
| 126 | 126 |
github.com/containerd/console 0650fd9eeb50bab4fc99dceb9f2e14cf58f36e7f |
| 127 | 127 |
github.com/containerd/go-runc 7d11b49dc0769f6dbb0d1b19f3d48524d1bad9ad |
| 128 | 128 |
github.com/containerd/typeurl 2a93cfde8c20b23de8eb84a5adbc234ddf7a9e8d |
| 129 |
-github.com/containerd/ttrpc 699c4e40d1e7416e08bf7019c7ce2e9beced4636 |
|
| 129 |
+github.com/containerd/ttrpc 92c8520ef9f86600c650dd540266a007bf03670f |
|
| 130 | 130 |
github.com/gogo/googleapis d31c731455cb061f42baff3bda55bad0118b126b # v1.2.0 |
| 131 | 131 |
|
| 132 | 132 |
# cluster |
| ... | ... |
@@ -18,7 +18,6 @@ package ttrpc |
| 18 | 18 |
|
| 19 | 19 |
import ( |
| 20 | 20 |
"bufio" |
| 21 |
- "context" |
|
| 22 | 21 |
"encoding/binary" |
| 23 | 22 |
"io" |
| 24 | 23 |
"net" |
| ... | ... |
@@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
|
| 98 | 98 |
// returned will be valid and caller should send that along to |
| 99 | 99 |
// the correct consumer. The bytes on the underlying channel |
| 100 | 100 |
// will be discarded. |
| 101 |
-func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
|
| 101 |
+func (ch *channel) recv() (messageHeader, []byte, error) {
|
|
| 102 | 102 |
mh, err := readMessageHeader(ch.hrbuf[:], ch.br) |
| 103 | 103 |
if err != nil {
|
| 104 | 104 |
return messageHeader{}, nil, err
|
| ... | ... |
@@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
| 120 | 120 |
return mh, p, nil |
| 121 | 121 |
} |
| 122 | 122 |
|
| 123 |
-func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
|
| 123 |
+func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
|
|
| 124 | 124 |
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
|
| 125 | 125 |
return err |
| 126 | 126 |
} |
| ... | ... |
@@ -29,6 +29,7 @@ import ( |
| 29 | 29 |
"github.com/gogo/protobuf/proto" |
| 30 | 30 |
"github.com/pkg/errors" |
| 31 | 31 |
"github.com/sirupsen/logrus" |
| 32 |
+ "google.golang.org/grpc/codes" |
|
| 32 | 33 |
"google.golang.org/grpc/status" |
| 33 | 34 |
) |
| 34 | 35 |
|
| ... | ... |
@@ -36,36 +37,52 @@ import ( |
| 36 | 36 |
// closed. |
| 37 | 37 |
var ErrClosed = errors.New("ttrpc: closed")
|
| 38 | 38 |
|
| 39 |
+// Client for a ttrpc server |
|
| 39 | 40 |
type Client struct {
|
| 40 | 41 |
codec codec |
| 41 | 42 |
conn net.Conn |
| 42 | 43 |
channel *channel |
| 43 | 44 |
calls chan *callRequest |
| 44 | 45 |
|
| 45 |
- closed chan struct{}
|
|
| 46 |
- closeOnce sync.Once |
|
| 47 |
- closeFunc func() |
|
| 48 |
- done chan struct{}
|
|
| 49 |
- err error |
|
| 46 |
+ ctx context.Context |
|
| 47 |
+ closed func() |
|
| 48 |
+ |
|
| 49 |
+ closeOnce sync.Once |
|
| 50 |
+ userCloseFunc func() |
|
| 51 |
+ |
|
| 52 |
+ errOnce sync.Once |
|
| 53 |
+ err error |
|
| 54 |
+ interceptor UnaryClientInterceptor |
|
| 50 | 55 |
} |
| 51 | 56 |
|
| 57 |
+// ClientOpts configures a client |
|
| 52 | 58 |
type ClientOpts func(c *Client) |
| 53 | 59 |
|
| 60 |
+// WithOnClose sets the close func whenever the client's Close() method is called |
|
| 54 | 61 |
func WithOnClose(onClose func()) ClientOpts {
|
| 55 | 62 |
return func(c *Client) {
|
| 56 |
- c.closeFunc = onClose |
|
| 63 |
+ c.userCloseFunc = onClose |
|
| 64 |
+ } |
|
| 65 |
+} |
|
| 66 |
+ |
|
| 67 |
+// WithUnaryClientInterceptor sets the provided client interceptor |
|
| 68 |
+func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
|
|
| 69 |
+ return func(c *Client) {
|
|
| 70 |
+ c.interceptor = i |
|
| 57 | 71 |
} |
| 58 | 72 |
} |
| 59 | 73 |
|
| 60 | 74 |
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
|
| 75 |
+ ctx, cancel := context.WithCancel(context.Background()) |
|
| 61 | 76 |
c := &Client{
|
| 62 |
- codec: codec{},
|
|
| 63 |
- conn: conn, |
|
| 64 |
- channel: newChannel(conn), |
|
| 65 |
- calls: make(chan *callRequest), |
|
| 66 |
- closed: make(chan struct{}),
|
|
| 67 |
- done: make(chan struct{}),
|
|
| 68 |
- closeFunc: func() {},
|
|
| 77 |
+ codec: codec{},
|
|
| 78 |
+ conn: conn, |
|
| 79 |
+ channel: newChannel(conn), |
|
| 80 |
+ calls: make(chan *callRequest), |
|
| 81 |
+ closed: cancel, |
|
| 82 |
+ ctx: ctx, |
|
| 83 |
+ userCloseFunc: func() {},
|
|
| 84 |
+ interceptor: defaultClientInterceptor, |
|
| 69 | 85 |
} |
| 70 | 86 |
|
| 71 | 87 |
for _, o := range opts {
|
| ... | ... |
@@ -99,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int |
| 99 | 99 |
cresp = &Response{}
|
| 100 | 100 |
) |
| 101 | 101 |
|
| 102 |
+ if metadata, ok := GetMetadata(ctx); ok {
|
|
| 103 |
+ metadata.setRequest(creq) |
|
| 104 |
+ } |
|
| 105 |
+ |
|
| 102 | 106 |
if dl, ok := ctx.Deadline(); ok {
|
| 103 | 107 |
creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() |
| 104 | 108 |
} |
| 105 | 109 |
|
| 106 |
- if err := c.dispatch(ctx, creq, cresp); err != nil {
|
|
| 110 |
+ info := &UnaryClientInfo{
|
|
| 111 |
+ FullMethod: fullPath(service, method), |
|
| 112 |
+ } |
|
| 113 |
+ if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
|
|
| 107 | 114 |
return err |
| 108 | 115 |
} |
| 109 | 116 |
|
| ... | ... |
@@ -111,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int |
| 111 | 111 |
return err |
| 112 | 112 |
} |
| 113 | 113 |
|
| 114 |
- if cresp.Status == nil {
|
|
| 115 |
- return errors.New("no status provided on response")
|
|
| 114 |
+ if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
|
|
| 115 |
+ return status.ErrorProto(cresp.Status) |
|
| 116 | 116 |
} |
| 117 |
- |
|
| 118 |
- return status.ErrorProto(cresp.Status) |
|
| 117 |
+ return nil |
|
| 119 | 118 |
} |
| 120 | 119 |
|
| 121 | 120 |
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
|
| ... | ... |
@@ -131,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err |
| 131 | 131 |
case <-ctx.Done(): |
| 132 | 132 |
return ctx.Err() |
| 133 | 133 |
case c.calls <- call: |
| 134 |
- case <-c.done: |
|
| 135 |
- return c.err |
|
| 134 |
+ case <-c.ctx.Done(): |
|
| 135 |
+ return c.error() |
|
| 136 | 136 |
} |
| 137 | 137 |
|
| 138 | 138 |
select {
|
| ... | ... |
@@ -140,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err |
| 140 | 140 |
return ctx.Err() |
| 141 | 141 |
case err := <-errs: |
| 142 | 142 |
return filterCloseErr(err) |
| 143 |
- case <-c.done: |
|
| 144 |
- return c.err |
|
| 143 |
+ case <-c.ctx.Done(): |
|
| 144 |
+ return c.error() |
|
| 145 | 145 |
} |
| 146 | 146 |
} |
| 147 | 147 |
|
| 148 | 148 |
func (c *Client) Close() error {
|
| 149 | 149 |
c.closeOnce.Do(func() {
|
| 150 |
- close(c.closed) |
|
| 150 |
+ c.closed() |
|
| 151 | 151 |
}) |
| 152 |
- |
|
| 153 | 152 |
return nil |
| 154 | 153 |
} |
| 155 | 154 |
|
| ... | ... |
@@ -159,51 +181,82 @@ type message struct {
|
| 159 | 159 |
err error |
| 160 | 160 |
} |
| 161 | 161 |
|
| 162 |
-func (c *Client) run() {
|
|
| 163 |
- var ( |
|
| 164 |
- streamID uint32 = 1 |
|
| 165 |
- waiters = make(map[uint32]*callRequest) |
|
| 166 |
- calls = c.calls |
|
| 167 |
- incoming = make(chan *message) |
|
| 168 |
- shutdown = make(chan struct{})
|
|
| 169 |
- shutdownErr error |
|
| 170 |
- ) |
|
| 162 |
+type receiver struct {
|
|
| 163 |
+ wg *sync.WaitGroup |
|
| 164 |
+ messages chan *message |
|
| 165 |
+ err error |
|
| 166 |
+} |
|
| 171 | 167 |
|
| 172 |
- go func() {
|
|
| 173 |
- defer close(shutdown) |
|
| 168 |
+func (r *receiver) run(ctx context.Context, c *channel) {
|
|
| 169 |
+ defer r.wg.Done() |
|
| 174 | 170 |
|
| 175 |
- // start one more goroutine to recv messages without blocking. |
|
| 176 |
- for {
|
|
| 177 |
- mh, p, err := c.channel.recv(context.TODO()) |
|
| 171 |
+ for {
|
|
| 172 |
+ select {
|
|
| 173 |
+ case <-ctx.Done(): |
|
| 174 |
+ r.err = ctx.Err() |
|
| 175 |
+ return |
|
| 176 |
+ default: |
|
| 177 |
+ mh, p, err := c.recv() |
|
| 178 | 178 |
if err != nil {
|
| 179 | 179 |
_, ok := status.FromError(err) |
| 180 | 180 |
if !ok {
|
| 181 | 181 |
// treat all errors that are not an rpc status as terminal. |
| 182 | 182 |
// all others poison the connection. |
| 183 |
- shutdownErr = err |
|
| 183 |
+ r.err = filterCloseErr(err) |
|
| 184 | 184 |
return |
| 185 | 185 |
} |
| 186 | 186 |
} |
| 187 | 187 |
select {
|
| 188 |
- case incoming <- &message{
|
|
| 188 |
+ case r.messages <- &message{
|
|
| 189 | 189 |
messageHeader: mh, |
| 190 | 190 |
p: p[:mh.Length], |
| 191 | 191 |
err: err, |
| 192 | 192 |
}: |
| 193 |
- case <-c.done: |
|
| 193 |
+ case <-ctx.Done(): |
|
| 194 |
+ r.err = ctx.Err() |
|
| 194 | 195 |
return |
| 195 | 196 |
} |
| 196 | 197 |
} |
| 198 |
+ } |
|
| 199 |
+} |
|
| 200 |
+ |
|
| 201 |
+func (c *Client) run() {
|
|
| 202 |
+ var ( |
|
| 203 |
+ streamID uint32 = 1 |
|
| 204 |
+ waiters = make(map[uint32]*callRequest) |
|
| 205 |
+ calls = c.calls |
|
| 206 |
+ incoming = make(chan *message) |
|
| 207 |
+ receiversDone = make(chan struct{})
|
|
| 208 |
+ wg sync.WaitGroup |
|
| 209 |
+ ) |
|
| 210 |
+ |
|
| 211 |
+ // broadcast the shutdown error to the remaining waiters. |
|
| 212 |
+ abortWaiters := func(wErr error) {
|
|
| 213 |
+ for _, waiter := range waiters {
|
|
| 214 |
+ waiter.errs <- wErr |
|
| 215 |
+ } |
|
| 216 |
+ } |
|
| 217 |
+ recv := &receiver{
|
|
| 218 |
+ wg: &wg, |
|
| 219 |
+ messages: incoming, |
|
| 220 |
+ } |
|
| 221 |
+ wg.Add(1) |
|
| 222 |
+ |
|
| 223 |
+ go func() {
|
|
| 224 |
+ wg.Wait() |
|
| 225 |
+ close(receiversDone) |
|
| 197 | 226 |
}() |
| 227 |
+ go recv.run(c.ctx, c.channel) |
|
| 198 | 228 |
|
| 199 |
- defer c.conn.Close() |
|
| 200 |
- defer close(c.done) |
|
| 201 |
- defer c.closeFunc() |
|
| 229 |
+ defer func() {
|
|
| 230 |
+ c.conn.Close() |
|
| 231 |
+ c.userCloseFunc() |
|
| 232 |
+ }() |
|
| 202 | 233 |
|
| 203 | 234 |
for {
|
| 204 | 235 |
select {
|
| 205 | 236 |
case call := <-calls: |
| 206 |
- if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
|
| 237 |
+ if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
|
|
| 207 | 238 |
call.errs <- err |
| 208 | 239 |
continue |
| 209 | 240 |
} |
| ... | ... |
@@ -219,41 +272,42 @@ func (c *Client) run() {
|
| 219 | 219 |
|
| 220 | 220 |
call.errs <- c.recv(call.resp, msg) |
| 221 | 221 |
delete(waiters, msg.StreamID) |
| 222 |
- case <-shutdown: |
|
| 223 |
- if shutdownErr != nil {
|
|
| 224 |
- shutdownErr = filterCloseErr(shutdownErr) |
|
| 225 |
- } else {
|
|
| 226 |
- shutdownErr = ErrClosed |
|
| 227 |
- } |
|
| 228 |
- |
|
| 229 |
- shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") |
|
| 230 |
- |
|
| 231 |
- c.err = shutdownErr |
|
| 232 |
- for _, waiter := range waiters {
|
|
| 233 |
- waiter.errs <- shutdownErr |
|
| 222 |
+ case <-receiversDone: |
|
| 223 |
+ // all the receivers have exited |
|
| 224 |
+ if recv.err != nil {
|
|
| 225 |
+ c.setError(recv.err) |
|
| 234 | 226 |
} |
| 227 |
+ // don't return out, let the close of the context trigger the abort of waiters |
|
| 235 | 228 |
c.Close() |
| 236 |
- return |
|
| 237 |
- case <-c.closed: |
|
| 238 |
- if c.err == nil {
|
|
| 239 |
- c.err = ErrClosed |
|
| 240 |
- } |
|
| 241 |
- // broadcast the shutdown error to the remaining waiters. |
|
| 242 |
- for _, waiter := range waiters {
|
|
| 243 |
- waiter.errs <- c.err |
|
| 244 |
- } |
|
| 229 |
+ case <-c.ctx.Done(): |
|
| 230 |
+ abortWaiters(c.error()) |
|
| 245 | 231 |
return |
| 246 | 232 |
} |
| 247 | 233 |
} |
| 248 | 234 |
} |
| 249 | 235 |
|
| 250 |
-func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
|
| 236 |
+func (c *Client) error() error {
|
|
| 237 |
+ c.errOnce.Do(func() {
|
|
| 238 |
+ if c.err == nil {
|
|
| 239 |
+ c.err = ErrClosed |
|
| 240 |
+ } |
|
| 241 |
+ }) |
|
| 242 |
+ return c.err |
|
| 243 |
+} |
|
| 244 |
+ |
|
| 245 |
+func (c *Client) setError(err error) {
|
|
| 246 |
+ c.errOnce.Do(func() {
|
|
| 247 |
+ c.err = err |
|
| 248 |
+ }) |
|
| 249 |
+} |
|
| 250 |
+ |
|
| 251 |
+func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
|
|
| 251 | 252 |
p, err := c.codec.Marshal(msg) |
| 252 | 253 |
if err != nil {
|
| 253 | 254 |
return err |
| 254 | 255 |
} |
| 255 | 256 |
|
| 256 |
- return c.channel.send(ctx, streamID, mtype, p) |
|
| 257 |
+ return c.channel.send(streamID, mtype, p) |
|
| 257 | 258 |
} |
| 258 | 259 |
|
| 259 | 260 |
func (c *Client) recv(resp *Response, msg *message) error {
|
| ... | ... |
@@ -274,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
|
| 274 | 274 |
// |
| 275 | 275 |
// This purposely ignores errors with a wrapped cause. |
| 276 | 276 |
func filterCloseErr(err error) error {
|
| 277 |
- if err == nil {
|
|
| 277 |
+ switch {
|
|
| 278 |
+ case err == nil: |
|
| 278 | 279 |
return nil |
| 279 |
- } |
|
| 280 |
- |
|
| 281 |
- if err == io.EOF {
|
|
| 280 |
+ case err == io.EOF: |
|
| 282 | 281 |
return ErrClosed |
| 283 |
- } |
|
| 284 |
- |
|
| 285 |
- if strings.Contains(err.Error(), "use of closed network connection") {
|
|
| 282 |
+ case errors.Cause(err) == io.EOF: |
|
| 286 | 283 |
return ErrClosed |
| 287 |
- } |
|
| 288 |
- |
|
| 289 |
- // if we have an epipe on a write, we cast to errclosed |
|
| 290 |
- if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
|
| 291 |
- if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
|
| 292 |
- return ErrClosed |
|
| 284 |
+ case strings.Contains(err.Error(), "use of closed network connection"): |
|
| 285 |
+ return ErrClosed |
|
| 286 |
+ default: |
|
| 287 |
+ // if we have an epipe on a write, we cast to errclosed |
|
| 288 |
+ if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
|
| 289 |
+ if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
|
| 290 |
+ return ErrClosed |
|
| 291 |
+ } |
|
| 293 | 292 |
} |
| 294 | 293 |
} |
| 295 | 294 |
|
| ... | ... |
@@ -19,9 +19,11 @@ package ttrpc |
| 19 | 19 |
import "github.com/pkg/errors" |
| 20 | 20 |
|
| 21 | 21 |
type serverConfig struct {
|
| 22 |
- handshaker Handshaker |
|
| 22 |
+ handshaker Handshaker |
|
| 23 |
+ interceptor UnaryServerInterceptor |
|
| 23 | 24 |
} |
| 24 | 25 |
|
| 26 |
+// ServerOpt for configuring a ttrpc server |
|
| 25 | 27 |
type ServerOpt func(*serverConfig) error |
| 26 | 28 |
|
| 27 | 29 |
// WithServerHandshaker can be passed to NewServer to ensure that the |
| ... | ... |
@@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
|
| 37 | 37 |
return nil |
| 38 | 38 |
} |
| 39 | 39 |
} |
| 40 |
+ |
|
| 41 |
+// WithUnaryServerInterceptor sets the provided interceptor on the server |
|
| 42 |
+func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
|
|
| 43 |
+ return func(c *serverConfig) error {
|
|
| 44 |
+ if c.interceptor != nil {
|
|
| 45 |
+ return errors.New("only one interceptor allowed per server")
|
|
| 46 |
+ } |
|
| 47 |
+ c.interceptor = i |
|
| 48 |
+ return nil |
|
| 49 |
+ } |
|
| 50 |
+} |
| 40 | 51 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,50 @@ |
| 0 |
+/* |
|
| 1 |
+ Copyright The containerd Authors. |
|
| 2 |
+ |
|
| 3 |
+ Licensed under the Apache License, Version 2.0 (the "License"); |
|
| 4 |
+ you may not use this file except in compliance with the License. |
|
| 5 |
+ You may obtain a copy of the License at |
|
| 6 |
+ |
|
| 7 |
+ http://www.apache.org/licenses/LICENSE-2.0 |
|
| 8 |
+ |
|
| 9 |
+ Unless required by applicable law or agreed to in writing, software |
|
| 10 |
+ distributed under the License is distributed on an "AS IS" BASIS, |
|
| 11 |
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
| 12 |
+ See the License for the specific language governing permissions and |
|
| 13 |
+ limitations under the License. |
|
| 14 |
+*/ |
|
| 15 |
+ |
|
| 16 |
+package ttrpc |
|
| 17 |
+ |
|
| 18 |
+import "context" |
|
| 19 |
+ |
|
| 20 |
+// UnaryServerInfo provides information about the server request |
|
| 21 |
+type UnaryServerInfo struct {
|
|
| 22 |
+ FullMethod string |
|
| 23 |
+} |
|
| 24 |
+ |
|
| 25 |
+// UnaryClientInfo provides information about the client request |
|
| 26 |
+type UnaryClientInfo struct {
|
|
| 27 |
+ FullMethod string |
|
| 28 |
+} |
|
| 29 |
+ |
|
| 30 |
+// Unmarshaler contains the server request data and allows it to be unmarshaled |
|
| 31 |
+// into a concrete type |
|
| 32 |
+type Unmarshaler func(interface{}) error
|
|
| 33 |
+ |
|
| 34 |
+// Invoker invokes the client's request and response from the ttrpc server |
|
| 35 |
+type Invoker func(context.Context, *Request, *Response) error |
|
| 36 |
+ |
|
| 37 |
+// UnaryServerInterceptor specifies the interceptor function for server request/response |
|
| 38 |
+type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
|
|
| 39 |
+ |
|
| 40 |
+// UnaryClientInterceptor specifies the interceptor function for client request/response |
|
| 41 |
+type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error |
|
| 42 |
+ |
|
| 43 |
+func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
|
|
| 44 |
+ return method(ctx, unmarshal) |
|
| 45 |
+} |
|
| 46 |
+ |
|
| 47 |
+func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
|
|
| 48 |
+ return invoker(ctx, req, resp) |
|
| 49 |
+} |
| 0 | 50 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,107 @@ |
| 0 |
+/* |
|
| 1 |
+ Copyright The containerd Authors. |
|
| 2 |
+ |
|
| 3 |
+ Licensed under the Apache License, Version 2.0 (the "License"); |
|
| 4 |
+ you may not use this file except in compliance with the License. |
|
| 5 |
+ You may obtain a copy of the License at |
|
| 6 |
+ |
|
| 7 |
+ http://www.apache.org/licenses/LICENSE-2.0 |
|
| 8 |
+ |
|
| 9 |
+ Unless required by applicable law or agreed to in writing, software |
|
| 10 |
+ distributed under the License is distributed on an "AS IS" BASIS, |
|
| 11 |
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
| 12 |
+ See the License for the specific language governing permissions and |
|
| 13 |
+ limitations under the License. |
|
| 14 |
+*/ |
|
| 15 |
+ |
|
| 16 |
+package ttrpc |
|
| 17 |
+ |
|
| 18 |
+import ( |
|
| 19 |
+ "context" |
|
| 20 |
+ "strings" |
|
| 21 |
+) |
|
| 22 |
+ |
|
| 23 |
+// MD is the user type for ttrpc metadata |
|
| 24 |
+type MD map[string][]string |
|
| 25 |
+ |
|
| 26 |
+// Get returns the metadata for a given key when they exist. |
|
| 27 |
+// If there is no metadata, a nil slice and false are returned. |
|
| 28 |
+func (m MD) Get(key string) ([]string, bool) {
|
|
| 29 |
+ key = strings.ToLower(key) |
|
| 30 |
+ list, ok := m[key] |
|
| 31 |
+ if !ok || len(list) == 0 {
|
|
| 32 |
+ return nil, false |
|
| 33 |
+ } |
|
| 34 |
+ |
|
| 35 |
+ return list, true |
|
| 36 |
+} |
|
| 37 |
+ |
|
| 38 |
+// Set sets the provided values for a given key. |
|
| 39 |
+// The values will overwrite any existing values. |
|
| 40 |
+// If no values provided, a key will be deleted. |
|
| 41 |
+func (m MD) Set(key string, values ...string) {
|
|
| 42 |
+ key = strings.ToLower(key) |
|
| 43 |
+ if len(values) == 0 {
|
|
| 44 |
+ delete(m, key) |
|
| 45 |
+ return |
|
| 46 |
+ } |
|
| 47 |
+ m[key] = values |
|
| 48 |
+} |
|
| 49 |
+ |
|
| 50 |
+// Append appends additional values to the given key. |
|
| 51 |
+func (m MD) Append(key string, values ...string) {
|
|
| 52 |
+ key = strings.ToLower(key) |
|
| 53 |
+ if len(values) == 0 {
|
|
| 54 |
+ return |
|
| 55 |
+ } |
|
| 56 |
+ current, ok := m[key] |
|
| 57 |
+ if ok {
|
|
| 58 |
+ m.Set(key, append(current, values...)...) |
|
| 59 |
+ } else {
|
|
| 60 |
+ m.Set(key, values...) |
|
| 61 |
+ } |
|
| 62 |
+} |
|
| 63 |
+ |
|
| 64 |
+func (m MD) setRequest(r *Request) {
|
|
| 65 |
+ for k, values := range m {
|
|
| 66 |
+ for _, v := range values {
|
|
| 67 |
+ r.Metadata = append(r.Metadata, &KeyValue{
|
|
| 68 |
+ Key: k, |
|
| 69 |
+ Value: v, |
|
| 70 |
+ }) |
|
| 71 |
+ } |
|
| 72 |
+ } |
|
| 73 |
+} |
|
| 74 |
+ |
|
| 75 |
+func (m MD) fromRequest(r *Request) {
|
|
| 76 |
+ for _, kv := range r.Metadata {
|
|
| 77 |
+ m[kv.Key] = append(m[kv.Key], kv.Value) |
|
| 78 |
+ } |
|
| 79 |
+} |
|
| 80 |
+ |
|
| 81 |
+type metadataKey struct{}
|
|
| 82 |
+ |
|
| 83 |
+// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) |
|
| 84 |
+func GetMetadata(ctx context.Context) (MD, bool) {
|
|
| 85 |
+ metadata, ok := ctx.Value(metadataKey{}).(MD)
|
|
| 86 |
+ return metadata, ok |
|
| 87 |
+} |
|
| 88 |
+ |
|
| 89 |
+// GetMetadataValue gets a specific metadata value by name from context.Context |
|
| 90 |
+func GetMetadataValue(ctx context.Context, name string) (string, bool) {
|
|
| 91 |
+ metadata, ok := GetMetadata(ctx) |
|
| 92 |
+ if !ok {
|
|
| 93 |
+ return "", false |
|
| 94 |
+ } |
|
| 95 |
+ |
|
| 96 |
+ if list, ok := metadata.Get(name); ok {
|
|
| 97 |
+ return list[0], true |
|
| 98 |
+ } |
|
| 99 |
+ |
|
| 100 |
+ return "", false |
|
| 101 |
+} |
|
| 102 |
+ |
|
| 103 |
+// WithMetadata attaches metadata map to a context.Context |
|
| 104 |
+func WithMetadata(ctx context.Context, md MD) context.Context {
|
|
| 105 |
+ return context.WithValue(ctx, metadataKey{}, md)
|
|
| 106 |
+} |
| ... | ... |
@@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
|
| 53 | 53 |
return nil, err |
| 54 | 54 |
} |
| 55 | 55 |
} |
| 56 |
+ if config.interceptor == nil {
|
|
| 57 |
+ config.interceptor = defaultServerInterceptor |
|
| 58 |
+ } |
|
| 56 | 59 |
|
| 57 | 60 |
return &Server{
|
| 58 | 61 |
config: config, |
| 59 |
- services: newServiceSet(), |
|
| 62 |
+ services: newServiceSet(config.interceptor), |
|
| 60 | 63 |
done: make(chan struct{}),
|
| 61 | 64 |
listeners: make(map[net.Listener]struct{}),
|
| 62 | 65 |
connections: make(map[*serverConn]struct{}),
|
| ... | ... |
@@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
|
| 341 | 341 |
default: // proceed |
| 342 | 342 |
} |
| 343 | 343 |
|
| 344 |
- mh, p, err := ch.recv(ctx) |
|
| 344 |
+ mh, p, err := ch.recv() |
|
| 345 | 345 |
if err != nil {
|
| 346 | 346 |
status, ok := status.FromError(err) |
| 347 | 347 |
if !ok {
|
| ... | ... |
@@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
|
| 438 | 438 |
return |
| 439 | 439 |
} |
| 440 | 440 |
|
| 441 |
- if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
|
| 441 |
+ if err := ch.send(response.id, messageTypeResponse, p); err != nil {
|
|
| 442 | 442 |
logrus.WithError(err).Error("failed sending message on channel")
|
| 443 | 443 |
return |
| 444 | 444 |
} |
| ... | ... |
@@ -449,7 +452,12 @@ func (c *serverConn) run(sctx context.Context) {
|
| 449 | 449 |
// branch. Basically, it means that we are no longer receiving |
| 450 | 450 |
// requests due to a terminal error. |
| 451 | 451 |
recvErr = nil // connection is now "closing" |
| 452 |
- if err != nil && err != io.EOF {
|
|
| 452 |
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
|
|
| 453 |
+ // The client went away and we should stop processing |
|
| 454 |
+ // requests, so that the client connection is closed |
|
| 455 |
+ return |
|
| 456 |
+ } |
|
| 457 |
+ if err != nil {
|
|
| 453 | 458 |
logrus.WithError(err).Error("error receiving message")
|
| 454 | 459 |
} |
| 455 | 460 |
case <-shutdown: |
| ... | ... |
@@ -461,6 +469,12 @@ func (c *serverConn) run(sctx context.Context) {
|
| 461 | 461 |
var noopFunc = func() {}
|
| 462 | 462 |
|
| 463 | 463 |
func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
|
| 464 |
+ if len(req.Metadata) > 0 {
|
|
| 465 |
+ md := MD{}
|
|
| 466 |
+ md.fromRequest(req) |
|
| 467 |
+ ctx = WithMetadata(ctx, md) |
|
| 468 |
+ } |
|
| 469 |
+ |
|
| 464 | 470 |
cancel = noopFunc |
| 465 | 471 |
if req.TimeoutNano == 0 {
|
| 466 | 472 |
return ctx, cancel |
| ... | ... |
@@ -37,12 +37,14 @@ type ServiceDesc struct {
|
| 37 | 37 |
} |
| 38 | 38 |
|
| 39 | 39 |
type serviceSet struct {
|
| 40 |
- services map[string]ServiceDesc |
|
| 40 |
+ services map[string]ServiceDesc |
|
| 41 |
+ interceptor UnaryServerInterceptor |
|
| 41 | 42 |
} |
| 42 | 43 |
|
| 43 |
-func newServiceSet() *serviceSet {
|
|
| 44 |
+func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
|
|
| 44 | 45 |
return &serviceSet{
|
| 45 |
- services: make(map[string]ServiceDesc), |
|
| 46 |
+ services: make(map[string]ServiceDesc), |
|
| 47 |
+ interceptor: interceptor, |
|
| 46 | 48 |
} |
| 47 | 49 |
} |
| 48 | 50 |
|
| ... | ... |
@@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin |
| 84 | 84 |
return nil |
| 85 | 85 |
} |
| 86 | 86 |
|
| 87 |
- resp, err := method(ctx, unmarshal) |
|
| 87 |
+ info := &UnaryServerInfo{
|
|
| 88 |
+ FullMethod: fullPath(serviceName, methodName), |
|
| 89 |
+ } |
|
| 90 |
+ |
|
| 91 |
+ resp, err := s.interceptor(ctx, unmarshal, info, method) |
|
| 88 | 92 |
if err != nil {
|
| 89 | 93 |
return nil, err |
| 90 | 94 |
} |
| ... | ... |
@@ -146,5 +152,5 @@ func convertCode(err error) codes.Code {
|
| 146 | 146 |
} |
| 147 | 147 |
|
| 148 | 148 |
func fullPath(service, method string) string {
|
| 149 |
- return "/" + path.Join("/", service, method)
|
|
| 149 |
+ return "/" + path.Join(service, method) |
|
| 150 | 150 |
} |
| ... | ... |
@@ -23,10 +23,11 @@ import ( |
| 23 | 23 |
) |
| 24 | 24 |
|
| 25 | 25 |
type Request struct {
|
| 26 |
- Service string `protobuf:"bytes,1,opt,name=service,proto3"` |
|
| 27 |
- Method string `protobuf:"bytes,2,opt,name=method,proto3"` |
|
| 28 |
- Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` |
|
| 29 |
- TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` |
|
| 26 |
+ Service string `protobuf:"bytes,1,opt,name=service,proto3"` |
|
| 27 |
+ Method string `protobuf:"bytes,2,opt,name=method,proto3"` |
|
| 28 |
+ Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` |
|
| 29 |
+ TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` |
|
| 30 |
+ Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"` |
|
| 30 | 31 |
} |
| 31 | 32 |
|
| 32 | 33 |
func (r *Request) Reset() { *r = Request{} }
|
| ... | ... |
@@ -41,3 +42,22 @@ type Response struct {
|
| 41 | 41 |
func (r *Response) Reset() { *r = Response{} }
|
| 42 | 42 |
func (r *Response) String() string { return fmt.Sprintf("%+#v", r) }
|
| 43 | 43 |
func (r *Response) ProtoMessage() {}
|
| 44 |
+ |
|
| 45 |
+type StringList struct {
|
|
| 46 |
+ List []string `protobuf:"bytes,1,rep,name=list,proto3"` |
|
| 47 |
+} |
|
| 48 |
+ |
|
| 49 |
+func (r *StringList) Reset() { *r = StringList{} }
|
|
| 50 |
+func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
|
|
| 51 |
+func (r *StringList) ProtoMessage() {}
|
|
| 52 |
+ |
|
| 53 |
+func makeStringList(item ...string) StringList { return StringList{List: item} }
|
|
| 54 |
+ |
|
| 55 |
+type KeyValue struct {
|
|
| 56 |
+ Key string `protobuf:"bytes,1,opt,name=key,proto3"` |
|
| 57 |
+ Value string `protobuf:"bytes,2,opt,name=value,proto3"` |
|
| 58 |
+} |
|
| 59 |
+ |
|
| 60 |
+func (m *KeyValue) Reset() { *m = KeyValue{} }
|
|
| 61 |
+func (*KeyValue) ProtoMessage() {}
|
|
| 62 |
+func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }
|