Browse code

Merge pull request #29787 from yongtang/29730-multiple-published-port

Fix issues of multiple published ports mapping to the same target port

Vincent Demeester authored on 2017/01/11 02:49:16
Showing 14 changed files
... ...
@@ -1572,3 +1572,22 @@ func (s *DockerSwarmSuite) TestSwarmServicePsMultipleServiceIDs(c *check.C) {
1572 1572
 	c.Assert(out, checker.Contains, name2+".2")
1573 1573
 	c.Assert(out, checker.Contains, name2+".3")
1574 1574
 }
1575
+
1576
+func (s *DockerSwarmSuite) TestSwarmPublishDuplicatePorts(c *check.C) {
1577
+	d := s.AddDaemon(c, true, true)
1578
+
1579
+	out, err := d.Cmd("service", "create", "--publish", "5000:80", "--publish", "5001:80", "--publish", "80", "--publish", "80", "busybox", "top")
1580
+	c.Assert(err, check.IsNil, check.Commentf(out))
1581
+	id := strings.TrimSpace(out)
1582
+
1583
+	// make sure task has been deployed.
1584
+	waitAndAssert(c, defaultReconciliationTimeout, d.CheckActiveContainerCount, checker.Equals, 1)
1585
+
1586
+	// Total len = 4, with 2 dynamic ports and 2 non-dynamic ports
1587
+	// Dynamic ports are likely to be 30000 and 30001 but doesn't matter
1588
+	out, err = d.Cmd("service", "inspect", "--format", "{{.Endpoint.Ports}} len={{len .Endpoint.Ports}}", id)
1589
+	c.Assert(err, check.IsNil, check.Commentf(out))
1590
+	c.Assert(out, checker.Contains, "len=4")
1591
+	c.Assert(out, checker.Contains, "{ tcp 80 5000 ingress}")
1592
+	c.Assert(out, checker.Contains, "{ tcp 80 5001 ingress}")
1593
+}
... ...
@@ -103,7 +103,7 @@ github.com/docker/containerd 03e5862ec0d8d3b3f750e19fca3ee367e13c090e
103 103
 github.com/tonistiigi/fifo 1405643975692217d6720f8b54aeee1bf2cd5cf4
104 104
 
105 105
 # cluster
106
-github.com/docker/swarmkit 4762d92234d286ae7c9e061470485e4d34ef8ebd
106
+github.com/docker/swarmkit c97146840a26c9ce8023284d0e9c989586cc1857
107 107
 github.com/golang/mock bd3c8e81be01eef76d4b503f5e687d2d1354d2d9
108 108
 github.com/gogo/protobuf v0.3
109 109
 github.com/cloudflare/cfssl 7fb22c8cba7ecaf98e4082d22d65800cf45e042a
... ...
@@ -836,12 +836,12 @@ func encodeVarintCa(data []byte, offset int, v uint64) int {
836 836
 }
837 837
 
838 838
 type raftProxyCAServer struct {
839
-	local        CAServer
840
-	connSelector raftselector.ConnProvider
841
-	ctxMods      []func(context.Context) (context.Context, error)
839
+	local                       CAServer
840
+	connSelector                raftselector.ConnProvider
841
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
842 842
 }
843 843
 
844
-func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) CAServer {
844
+func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) CAServer {
845 845
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
846 846
 		s, ok := transport.StreamFromContext(ctx)
847 847
 		if !ok {
... ...
@@ -858,18 +858,24 @@ func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider
858 858
 		md["redirect"] = append(md["redirect"], addr)
859 859
 		return metadata.NewContext(ctx, md), nil
860 860
 	}
861
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
862
-	mods = append(mods, ctxMod)
861
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
862
+	remoteMods = append(remoteMods, remoteCtxMod)
863
+
864
+	var localMods []func(context.Context) (context.Context, error)
865
+	if localCtxMod != nil {
866
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
867
+	}
863 868
 
864 869
 	return &raftProxyCAServer{
865
-		local:        local,
866
-		connSelector: connSelector,
867
-		ctxMods:      mods,
870
+		local:         local,
871
+		connSelector:  connSelector,
872
+		localCtxMods:  localMods,
873
+		remoteCtxMods: remoteMods,
868 874
 	}
869 875
 }
870
-func (p *raftProxyCAServer) runCtxMods(ctx context.Context) (context.Context, error) {
876
+func (p *raftProxyCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
871 877
 	var err error
872
-	for _, mod := range p.ctxMods {
878
+	for _, mod := range ctxMods {
873 879
 		ctx, err = mod(ctx)
874 880
 		if err != nil {
875 881
 			return ctx, err
... ...
@@ -906,11 +912,15 @@ func (p *raftProxyCAServer) GetRootCACertificate(ctx context.Context, r *GetRoot
906 906
 	conn, err := p.connSelector.LeaderConn(ctx)
907 907
 	if err != nil {
908 908
 		if err == raftselector.ErrIsLeader {
909
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
910
+			if err != nil {
911
+				return nil, err
912
+			}
909 913
 			return p.local.GetRootCACertificate(ctx, r)
910 914
 		}
911 915
 		return nil, err
912 916
 	}
913
-	modCtx, err := p.runCtxMods(ctx)
917
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
914 918
 	if err != nil {
915 919
 		return nil, err
916 920
 	}
... ...
@@ -937,11 +947,15 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq
937 937
 	conn, err := p.connSelector.LeaderConn(ctx)
938 938
 	if err != nil {
939 939
 		if err == raftselector.ErrIsLeader {
940
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
941
+			if err != nil {
942
+				return nil, err
943
+			}
940 944
 			return p.local.GetUnlockKey(ctx, r)
941 945
 		}
942 946
 		return nil, err
943 947
 	}
944
-	modCtx, err := p.runCtxMods(ctx)
948
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
945 949
 	if err != nil {
946 950
 		return nil, err
947 951
 	}
... ...
@@ -964,12 +978,12 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq
964 964
 }
965 965
 
966 966
 type raftProxyNodeCAServer struct {
967
-	local        NodeCAServer
968
-	connSelector raftselector.ConnProvider
969
-	ctxMods      []func(context.Context) (context.Context, error)
967
+	local                       NodeCAServer
968
+	connSelector                raftselector.ConnProvider
969
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
970 970
 }
971 971
 
972
-func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) NodeCAServer {
972
+func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) NodeCAServer {
973 973
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
974 974
 		s, ok := transport.StreamFromContext(ctx)
975 975
 		if !ok {
... ...
@@ -986,18 +1000,24 @@ func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.Conn
986 986
 		md["redirect"] = append(md["redirect"], addr)
987 987
 		return metadata.NewContext(ctx, md), nil
988 988
 	}
989
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
990
-	mods = append(mods, ctxMod)
989
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
990
+	remoteMods = append(remoteMods, remoteCtxMod)
991
+
992
+	var localMods []func(context.Context) (context.Context, error)
993
+	if localCtxMod != nil {
994
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
995
+	}
991 996
 
992 997
 	return &raftProxyNodeCAServer{
993
-		local:        local,
994
-		connSelector: connSelector,
995
-		ctxMods:      mods,
998
+		local:         local,
999
+		connSelector:  connSelector,
1000
+		localCtxMods:  localMods,
1001
+		remoteCtxMods: remoteMods,
996 1002
 	}
997 1003
 }
998
-func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context) (context.Context, error) {
1004
+func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
999 1005
 	var err error
1000
-	for _, mod := range p.ctxMods {
1006
+	for _, mod := range ctxMods {
1001 1007
 		ctx, err = mod(ctx)
1002 1008
 		if err != nil {
1003 1009
 			return ctx, err
... ...
@@ -1034,11 +1054,15 @@ func (p *raftProxyNodeCAServer) IssueNodeCertificate(ctx context.Context, r *Iss
1034 1034
 	conn, err := p.connSelector.LeaderConn(ctx)
1035 1035
 	if err != nil {
1036 1036
 		if err == raftselector.ErrIsLeader {
1037
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1038
+			if err != nil {
1039
+				return nil, err
1040
+			}
1037 1041
 			return p.local.IssueNodeCertificate(ctx, r)
1038 1042
 		}
1039 1043
 		return nil, err
1040 1044
 	}
1041
-	modCtx, err := p.runCtxMods(ctx)
1045
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1042 1046
 	if err != nil {
1043 1047
 		return nil, err
1044 1048
 	}
... ...
@@ -1065,11 +1089,15 @@ func (p *raftProxyNodeCAServer) NodeCertificateStatus(ctx context.Context, r *No
1065 1065
 	conn, err := p.connSelector.LeaderConn(ctx)
1066 1066
 	if err != nil {
1067 1067
 		if err == raftselector.ErrIsLeader {
1068
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1069
+			if err != nil {
1070
+				return nil, err
1071
+			}
1068 1072
 			return p.local.NodeCertificateStatus(ctx, r)
1069 1073
 		}
1070 1074
 		return nil, err
1071 1075
 	}
1072
-	modCtx, err := p.runCtxMods(ctx)
1076
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1073 1077
 	if err != nil {
1074 1078
 		return nil, err
1075 1079
 	}
... ...
@@ -5256,12 +5256,12 @@ func encodeVarintControl(data []byte, offset int, v uint64) int {
5256 5256
 }
5257 5257
 
5258 5258
 type raftProxyControlServer struct {
5259
-	local        ControlServer
5260
-	connSelector raftselector.ConnProvider
5261
-	ctxMods      []func(context.Context) (context.Context, error)
5259
+	local                       ControlServer
5260
+	connSelector                raftselector.ConnProvider
5261
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
5262 5262
 }
5263 5263
 
5264
-func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ControlServer {
5264
+func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ControlServer {
5265 5265
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
5266 5266
 		s, ok := transport.StreamFromContext(ctx)
5267 5267
 		if !ok {
... ...
@@ -5278,18 +5278,24 @@ func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.Co
5278 5278
 		md["redirect"] = append(md["redirect"], addr)
5279 5279
 		return metadata.NewContext(ctx, md), nil
5280 5280
 	}
5281
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
5282
-	mods = append(mods, ctxMod)
5281
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
5282
+	remoteMods = append(remoteMods, remoteCtxMod)
5283
+
5284
+	var localMods []func(context.Context) (context.Context, error)
5285
+	if localCtxMod != nil {
5286
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
5287
+	}
5283 5288
 
5284 5289
 	return &raftProxyControlServer{
5285
-		local:        local,
5286
-		connSelector: connSelector,
5287
-		ctxMods:      mods,
5290
+		local:         local,
5291
+		connSelector:  connSelector,
5292
+		localCtxMods:  localMods,
5293
+		remoteCtxMods: remoteMods,
5288 5294
 	}
5289 5295
 }
5290
-func (p *raftProxyControlServer) runCtxMods(ctx context.Context) (context.Context, error) {
5296
+func (p *raftProxyControlServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
5291 5297
 	var err error
5292
-	for _, mod := range p.ctxMods {
5298
+	for _, mod := range ctxMods {
5293 5299
 		ctx, err = mod(ctx)
5294 5300
 		if err != nil {
5295 5301
 			return ctx, err
... ...
@@ -5326,11 +5332,15 @@ func (p *raftProxyControlServer) GetNode(ctx context.Context, r *GetNodeRequest)
5326 5326
 	conn, err := p.connSelector.LeaderConn(ctx)
5327 5327
 	if err != nil {
5328 5328
 		if err == raftselector.ErrIsLeader {
5329
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5330
+			if err != nil {
5331
+				return nil, err
5332
+			}
5329 5333
 			return p.local.GetNode(ctx, r)
5330 5334
 		}
5331 5335
 		return nil, err
5332 5336
 	}
5333
-	modCtx, err := p.runCtxMods(ctx)
5337
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5334 5338
 	if err != nil {
5335 5339
 		return nil, err
5336 5340
 	}
... ...
@@ -5357,11 +5367,15 @@ func (p *raftProxyControlServer) ListNodes(ctx context.Context, r *ListNodesRequ
5357 5357
 	conn, err := p.connSelector.LeaderConn(ctx)
5358 5358
 	if err != nil {
5359 5359
 		if err == raftselector.ErrIsLeader {
5360
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5361
+			if err != nil {
5362
+				return nil, err
5363
+			}
5360 5364
 			return p.local.ListNodes(ctx, r)
5361 5365
 		}
5362 5366
 		return nil, err
5363 5367
 	}
5364
-	modCtx, err := p.runCtxMods(ctx)
5368
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5365 5369
 	if err != nil {
5366 5370
 		return nil, err
5367 5371
 	}
... ...
@@ -5388,11 +5402,15 @@ func (p *raftProxyControlServer) UpdateNode(ctx context.Context, r *UpdateNodeRe
5388 5388
 	conn, err := p.connSelector.LeaderConn(ctx)
5389 5389
 	if err != nil {
5390 5390
 		if err == raftselector.ErrIsLeader {
5391
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5392
+			if err != nil {
5393
+				return nil, err
5394
+			}
5391 5395
 			return p.local.UpdateNode(ctx, r)
5392 5396
 		}
5393 5397
 		return nil, err
5394 5398
 	}
5395
-	modCtx, err := p.runCtxMods(ctx)
5399
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5396 5400
 	if err != nil {
5397 5401
 		return nil, err
5398 5402
 	}
... ...
@@ -5419,11 +5437,15 @@ func (p *raftProxyControlServer) RemoveNode(ctx context.Context, r *RemoveNodeRe
5419 5419
 	conn, err := p.connSelector.LeaderConn(ctx)
5420 5420
 	if err != nil {
5421 5421
 		if err == raftselector.ErrIsLeader {
5422
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5423
+			if err != nil {
5424
+				return nil, err
5425
+			}
5422 5426
 			return p.local.RemoveNode(ctx, r)
5423 5427
 		}
5424 5428
 		return nil, err
5425 5429
 	}
5426
-	modCtx, err := p.runCtxMods(ctx)
5430
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5427 5431
 	if err != nil {
5428 5432
 		return nil, err
5429 5433
 	}
... ...
@@ -5450,11 +5472,15 @@ func (p *raftProxyControlServer) GetTask(ctx context.Context, r *GetTaskRequest)
5450 5450
 	conn, err := p.connSelector.LeaderConn(ctx)
5451 5451
 	if err != nil {
5452 5452
 		if err == raftselector.ErrIsLeader {
5453
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5454
+			if err != nil {
5455
+				return nil, err
5456
+			}
5453 5457
 			return p.local.GetTask(ctx, r)
5454 5458
 		}
5455 5459
 		return nil, err
5456 5460
 	}
5457
-	modCtx, err := p.runCtxMods(ctx)
5461
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5458 5462
 	if err != nil {
5459 5463
 		return nil, err
5460 5464
 	}
... ...
@@ -5481,11 +5507,15 @@ func (p *raftProxyControlServer) ListTasks(ctx context.Context, r *ListTasksRequ
5481 5481
 	conn, err := p.connSelector.LeaderConn(ctx)
5482 5482
 	if err != nil {
5483 5483
 		if err == raftselector.ErrIsLeader {
5484
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5485
+			if err != nil {
5486
+				return nil, err
5487
+			}
5484 5488
 			return p.local.ListTasks(ctx, r)
5485 5489
 		}
5486 5490
 		return nil, err
5487 5491
 	}
5488
-	modCtx, err := p.runCtxMods(ctx)
5492
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5489 5493
 	if err != nil {
5490 5494
 		return nil, err
5491 5495
 	}
... ...
@@ -5512,11 +5542,15 @@ func (p *raftProxyControlServer) RemoveTask(ctx context.Context, r *RemoveTaskRe
5512 5512
 	conn, err := p.connSelector.LeaderConn(ctx)
5513 5513
 	if err != nil {
5514 5514
 		if err == raftselector.ErrIsLeader {
5515
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5516
+			if err != nil {
5517
+				return nil, err
5518
+			}
5515 5519
 			return p.local.RemoveTask(ctx, r)
5516 5520
 		}
5517 5521
 		return nil, err
5518 5522
 	}
5519
-	modCtx, err := p.runCtxMods(ctx)
5523
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5520 5524
 	if err != nil {
5521 5525
 		return nil, err
5522 5526
 	}
... ...
@@ -5543,11 +5577,15 @@ func (p *raftProxyControlServer) GetService(ctx context.Context, r *GetServiceRe
5543 5543
 	conn, err := p.connSelector.LeaderConn(ctx)
5544 5544
 	if err != nil {
5545 5545
 		if err == raftselector.ErrIsLeader {
5546
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5547
+			if err != nil {
5548
+				return nil, err
5549
+			}
5546 5550
 			return p.local.GetService(ctx, r)
5547 5551
 		}
5548 5552
 		return nil, err
5549 5553
 	}
5550
-	modCtx, err := p.runCtxMods(ctx)
5554
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5551 5555
 	if err != nil {
5552 5556
 		return nil, err
5553 5557
 	}
... ...
@@ -5574,11 +5612,15 @@ func (p *raftProxyControlServer) ListServices(ctx context.Context, r *ListServic
5574 5574
 	conn, err := p.connSelector.LeaderConn(ctx)
5575 5575
 	if err != nil {
5576 5576
 		if err == raftselector.ErrIsLeader {
5577
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5578
+			if err != nil {
5579
+				return nil, err
5580
+			}
5577 5581
 			return p.local.ListServices(ctx, r)
5578 5582
 		}
5579 5583
 		return nil, err
5580 5584
 	}
5581
-	modCtx, err := p.runCtxMods(ctx)
5585
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5582 5586
 	if err != nil {
5583 5587
 		return nil, err
5584 5588
 	}
... ...
@@ -5605,11 +5647,15 @@ func (p *raftProxyControlServer) CreateService(ctx context.Context, r *CreateSer
5605 5605
 	conn, err := p.connSelector.LeaderConn(ctx)
5606 5606
 	if err != nil {
5607 5607
 		if err == raftselector.ErrIsLeader {
5608
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5609
+			if err != nil {
5610
+				return nil, err
5611
+			}
5608 5612
 			return p.local.CreateService(ctx, r)
5609 5613
 		}
5610 5614
 		return nil, err
5611 5615
 	}
5612
-	modCtx, err := p.runCtxMods(ctx)
5616
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5613 5617
 	if err != nil {
5614 5618
 		return nil, err
5615 5619
 	}
... ...
@@ -5636,11 +5682,15 @@ func (p *raftProxyControlServer) UpdateService(ctx context.Context, r *UpdateSer
5636 5636
 	conn, err := p.connSelector.LeaderConn(ctx)
5637 5637
 	if err != nil {
5638 5638
 		if err == raftselector.ErrIsLeader {
5639
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5640
+			if err != nil {
5641
+				return nil, err
5642
+			}
5639 5643
 			return p.local.UpdateService(ctx, r)
5640 5644
 		}
5641 5645
 		return nil, err
5642 5646
 	}
5643
-	modCtx, err := p.runCtxMods(ctx)
5647
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5644 5648
 	if err != nil {
5645 5649
 		return nil, err
5646 5650
 	}
... ...
@@ -5667,11 +5717,15 @@ func (p *raftProxyControlServer) RemoveService(ctx context.Context, r *RemoveSer
5667 5667
 	conn, err := p.connSelector.LeaderConn(ctx)
5668 5668
 	if err != nil {
5669 5669
 		if err == raftselector.ErrIsLeader {
5670
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5671
+			if err != nil {
5672
+				return nil, err
5673
+			}
5670 5674
 			return p.local.RemoveService(ctx, r)
5671 5675
 		}
5672 5676
 		return nil, err
5673 5677
 	}
5674
-	modCtx, err := p.runCtxMods(ctx)
5678
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5675 5679
 	if err != nil {
5676 5680
 		return nil, err
5677 5681
 	}
... ...
@@ -5698,11 +5752,15 @@ func (p *raftProxyControlServer) GetNetwork(ctx context.Context, r *GetNetworkRe
5698 5698
 	conn, err := p.connSelector.LeaderConn(ctx)
5699 5699
 	if err != nil {
5700 5700
 		if err == raftselector.ErrIsLeader {
5701
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5702
+			if err != nil {
5703
+				return nil, err
5704
+			}
5701 5705
 			return p.local.GetNetwork(ctx, r)
5702 5706
 		}
5703 5707
 		return nil, err
5704 5708
 	}
5705
-	modCtx, err := p.runCtxMods(ctx)
5709
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5706 5710
 	if err != nil {
5707 5711
 		return nil, err
5708 5712
 	}
... ...
@@ -5729,11 +5787,15 @@ func (p *raftProxyControlServer) ListNetworks(ctx context.Context, r *ListNetwor
5729 5729
 	conn, err := p.connSelector.LeaderConn(ctx)
5730 5730
 	if err != nil {
5731 5731
 		if err == raftselector.ErrIsLeader {
5732
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5733
+			if err != nil {
5734
+				return nil, err
5735
+			}
5732 5736
 			return p.local.ListNetworks(ctx, r)
5733 5737
 		}
5734 5738
 		return nil, err
5735 5739
 	}
5736
-	modCtx, err := p.runCtxMods(ctx)
5740
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5737 5741
 	if err != nil {
5738 5742
 		return nil, err
5739 5743
 	}
... ...
@@ -5760,11 +5822,15 @@ func (p *raftProxyControlServer) CreateNetwork(ctx context.Context, r *CreateNet
5760 5760
 	conn, err := p.connSelector.LeaderConn(ctx)
5761 5761
 	if err != nil {
5762 5762
 		if err == raftselector.ErrIsLeader {
5763
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5764
+			if err != nil {
5765
+				return nil, err
5766
+			}
5763 5767
 			return p.local.CreateNetwork(ctx, r)
5764 5768
 		}
5765 5769
 		return nil, err
5766 5770
 	}
5767
-	modCtx, err := p.runCtxMods(ctx)
5771
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5768 5772
 	if err != nil {
5769 5773
 		return nil, err
5770 5774
 	}
... ...
@@ -5791,11 +5857,15 @@ func (p *raftProxyControlServer) RemoveNetwork(ctx context.Context, r *RemoveNet
5791 5791
 	conn, err := p.connSelector.LeaderConn(ctx)
5792 5792
 	if err != nil {
5793 5793
 		if err == raftselector.ErrIsLeader {
5794
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5795
+			if err != nil {
5796
+				return nil, err
5797
+			}
5794 5798
 			return p.local.RemoveNetwork(ctx, r)
5795 5799
 		}
5796 5800
 		return nil, err
5797 5801
 	}
5798
-	modCtx, err := p.runCtxMods(ctx)
5802
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5799 5803
 	if err != nil {
5800 5804
 		return nil, err
5801 5805
 	}
... ...
@@ -5822,11 +5892,15 @@ func (p *raftProxyControlServer) GetCluster(ctx context.Context, r *GetClusterRe
5822 5822
 	conn, err := p.connSelector.LeaderConn(ctx)
5823 5823
 	if err != nil {
5824 5824
 		if err == raftselector.ErrIsLeader {
5825
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5826
+			if err != nil {
5827
+				return nil, err
5828
+			}
5825 5829
 			return p.local.GetCluster(ctx, r)
5826 5830
 		}
5827 5831
 		return nil, err
5828 5832
 	}
5829
-	modCtx, err := p.runCtxMods(ctx)
5833
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5830 5834
 	if err != nil {
5831 5835
 		return nil, err
5832 5836
 	}
... ...
@@ -5853,11 +5927,15 @@ func (p *raftProxyControlServer) ListClusters(ctx context.Context, r *ListCluste
5853 5853
 	conn, err := p.connSelector.LeaderConn(ctx)
5854 5854
 	if err != nil {
5855 5855
 		if err == raftselector.ErrIsLeader {
5856
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5857
+			if err != nil {
5858
+				return nil, err
5859
+			}
5856 5860
 			return p.local.ListClusters(ctx, r)
5857 5861
 		}
5858 5862
 		return nil, err
5859 5863
 	}
5860
-	modCtx, err := p.runCtxMods(ctx)
5864
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5861 5865
 	if err != nil {
5862 5866
 		return nil, err
5863 5867
 	}
... ...
@@ -5884,11 +5962,15 @@ func (p *raftProxyControlServer) UpdateCluster(ctx context.Context, r *UpdateClu
5884 5884
 	conn, err := p.connSelector.LeaderConn(ctx)
5885 5885
 	if err != nil {
5886 5886
 		if err == raftselector.ErrIsLeader {
5887
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5888
+			if err != nil {
5889
+				return nil, err
5890
+			}
5887 5891
 			return p.local.UpdateCluster(ctx, r)
5888 5892
 		}
5889 5893
 		return nil, err
5890 5894
 	}
5891
-	modCtx, err := p.runCtxMods(ctx)
5895
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5892 5896
 	if err != nil {
5893 5897
 		return nil, err
5894 5898
 	}
... ...
@@ -5915,11 +5997,15 @@ func (p *raftProxyControlServer) GetSecret(ctx context.Context, r *GetSecretRequ
5915 5915
 	conn, err := p.connSelector.LeaderConn(ctx)
5916 5916
 	if err != nil {
5917 5917
 		if err == raftselector.ErrIsLeader {
5918
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5919
+			if err != nil {
5920
+				return nil, err
5921
+			}
5918 5922
 			return p.local.GetSecret(ctx, r)
5919 5923
 		}
5920 5924
 		return nil, err
5921 5925
 	}
5922
-	modCtx, err := p.runCtxMods(ctx)
5926
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5923 5927
 	if err != nil {
5924 5928
 		return nil, err
5925 5929
 	}
... ...
@@ -5946,11 +6032,15 @@ func (p *raftProxyControlServer) UpdateSecret(ctx context.Context, r *UpdateSecr
5946 5946
 	conn, err := p.connSelector.LeaderConn(ctx)
5947 5947
 	if err != nil {
5948 5948
 		if err == raftselector.ErrIsLeader {
5949
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5950
+			if err != nil {
5951
+				return nil, err
5952
+			}
5949 5953
 			return p.local.UpdateSecret(ctx, r)
5950 5954
 		}
5951 5955
 		return nil, err
5952 5956
 	}
5953
-	modCtx, err := p.runCtxMods(ctx)
5957
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5954 5958
 	if err != nil {
5955 5959
 		return nil, err
5956 5960
 	}
... ...
@@ -5977,11 +6067,15 @@ func (p *raftProxyControlServer) ListSecrets(ctx context.Context, r *ListSecrets
5977 5977
 	conn, err := p.connSelector.LeaderConn(ctx)
5978 5978
 	if err != nil {
5979 5979
 		if err == raftselector.ErrIsLeader {
5980
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
5981
+			if err != nil {
5982
+				return nil, err
5983
+			}
5980 5984
 			return p.local.ListSecrets(ctx, r)
5981 5985
 		}
5982 5986
 		return nil, err
5983 5987
 	}
5984
-	modCtx, err := p.runCtxMods(ctx)
5988
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
5985 5989
 	if err != nil {
5986 5990
 		return nil, err
5987 5991
 	}
... ...
@@ -6008,11 +6102,15 @@ func (p *raftProxyControlServer) CreateSecret(ctx context.Context, r *CreateSecr
6008 6008
 	conn, err := p.connSelector.LeaderConn(ctx)
6009 6009
 	if err != nil {
6010 6010
 		if err == raftselector.ErrIsLeader {
6011
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
6012
+			if err != nil {
6013
+				return nil, err
6014
+			}
6011 6015
 			return p.local.CreateSecret(ctx, r)
6012 6016
 		}
6013 6017
 		return nil, err
6014 6018
 	}
6015
-	modCtx, err := p.runCtxMods(ctx)
6019
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
6016 6020
 	if err != nil {
6017 6021
 		return nil, err
6018 6022
 	}
... ...
@@ -6039,11 +6137,15 @@ func (p *raftProxyControlServer) RemoveSecret(ctx context.Context, r *RemoveSecr
6039 6039
 	conn, err := p.connSelector.LeaderConn(ctx)
6040 6040
 	if err != nil {
6041 6041
 		if err == raftselector.ErrIsLeader {
6042
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
6043
+			if err != nil {
6044
+				return nil, err
6045
+			}
6042 6046
 			return p.local.RemoveSecret(ctx, r)
6043 6047
 		}
6044 6048
 		return nil, err
6045 6049
 	}
6046
-	modCtx, err := p.runCtxMods(ctx)
6050
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
6047 6051
 	if err != nil {
6048 6052
 		return nil, err
6049 6053
 	}
... ...
@@ -1670,12 +1670,12 @@ func encodeVarintDispatcher(data []byte, offset int, v uint64) int {
1670 1670
 }
1671 1671
 
1672 1672
 type raftProxyDispatcherServer struct {
1673
-	local        DispatcherServer
1674
-	connSelector raftselector.ConnProvider
1675
-	ctxMods      []func(context.Context) (context.Context, error)
1673
+	local                       DispatcherServer
1674
+	connSelector                raftselector.ConnProvider
1675
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
1676 1676
 }
1677 1677
 
1678
-func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) DispatcherServer {
1678
+func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) DispatcherServer {
1679 1679
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
1680 1680
 		s, ok := transport.StreamFromContext(ctx)
1681 1681
 		if !ok {
... ...
@@ -1692,18 +1692,24 @@ func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselec
1692 1692
 		md["redirect"] = append(md["redirect"], addr)
1693 1693
 		return metadata.NewContext(ctx, md), nil
1694 1694
 	}
1695
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
1696
-	mods = append(mods, ctxMod)
1695
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
1696
+	remoteMods = append(remoteMods, remoteCtxMod)
1697
+
1698
+	var localMods []func(context.Context) (context.Context, error)
1699
+	if localCtxMod != nil {
1700
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
1701
+	}
1697 1702
 
1698 1703
 	return &raftProxyDispatcherServer{
1699
-		local:        local,
1700
-		connSelector: connSelector,
1701
-		ctxMods:      mods,
1704
+		local:         local,
1705
+		connSelector:  connSelector,
1706
+		localCtxMods:  localMods,
1707
+		remoteCtxMods: remoteMods,
1702 1708
 	}
1703 1709
 }
1704
-func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context) (context.Context, error) {
1710
+func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
1705 1711
 	var err error
1706
-	for _, mod := range p.ctxMods {
1712
+	for _, mod := range ctxMods {
1707 1713
 		ctx, err = mod(ctx)
1708 1714
 		if err != nil {
1709 1715
 			return ctx, err
... ...
@@ -1735,17 +1741,33 @@ func (p *raftProxyDispatcherServer) pollNewLeaderConn(ctx context.Context) (*grp
1735 1735
 	}
1736 1736
 }
1737 1737
 
1738
-func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error {
1738
+type Dispatcher_SessionServerWrapper struct {
1739
+	Dispatcher_SessionServer
1740
+	ctx context.Context
1741
+}
1739 1742
 
1743
+func (s Dispatcher_SessionServerWrapper) Context() context.Context {
1744
+	return s.ctx
1745
+}
1746
+
1747
+func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error {
1740 1748
 	ctx := stream.Context()
1741 1749
 	conn, err := p.connSelector.LeaderConn(ctx)
1742 1750
 	if err != nil {
1743 1751
 		if err == raftselector.ErrIsLeader {
1744
-			return p.local.Session(r, stream)
1752
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1753
+			if err != nil {
1754
+				return err
1755
+			}
1756
+			streamWrapper := Dispatcher_SessionServerWrapper{
1757
+				Dispatcher_SessionServer: stream,
1758
+				ctx: ctx,
1759
+			}
1760
+			return p.local.Session(r, streamWrapper)
1745 1761
 		}
1746 1762
 		return err
1747 1763
 	}
1748
-	ctx, err = p.runCtxMods(ctx)
1764
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1749 1765
 	if err != nil {
1750 1766
 		return err
1751 1767
 	}
... ...
@@ -1775,11 +1797,15 @@ func (p *raftProxyDispatcherServer) Heartbeat(ctx context.Context, r *HeartbeatR
1775 1775
 	conn, err := p.connSelector.LeaderConn(ctx)
1776 1776
 	if err != nil {
1777 1777
 		if err == raftselector.ErrIsLeader {
1778
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1779
+			if err != nil {
1780
+				return nil, err
1781
+			}
1778 1782
 			return p.local.Heartbeat(ctx, r)
1779 1783
 		}
1780 1784
 		return nil, err
1781 1785
 	}
1782
-	modCtx, err := p.runCtxMods(ctx)
1786
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1783 1787
 	if err != nil {
1784 1788
 		return nil, err
1785 1789
 	}
... ...
@@ -1806,11 +1832,15 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd
1806 1806
 	conn, err := p.connSelector.LeaderConn(ctx)
1807 1807
 	if err != nil {
1808 1808
 		if err == raftselector.ErrIsLeader {
1809
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1810
+			if err != nil {
1811
+				return nil, err
1812
+			}
1809 1813
 			return p.local.UpdateTaskStatus(ctx, r)
1810 1814
 		}
1811 1815
 		return nil, err
1812 1816
 	}
1813
-	modCtx, err := p.runCtxMods(ctx)
1817
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1814 1818
 	if err != nil {
1815 1819
 		return nil, err
1816 1820
 	}
... ...
@@ -1832,17 +1862,33 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd
1832 1832
 	return resp, err
1833 1833
 }
1834 1834
 
1835
-func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error {
1835
+type Dispatcher_TasksServerWrapper struct {
1836
+	Dispatcher_TasksServer
1837
+	ctx context.Context
1838
+}
1839
+
1840
+func (s Dispatcher_TasksServerWrapper) Context() context.Context {
1841
+	return s.ctx
1842
+}
1836 1843
 
1844
+func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error {
1837 1845
 	ctx := stream.Context()
1838 1846
 	conn, err := p.connSelector.LeaderConn(ctx)
1839 1847
 	if err != nil {
1840 1848
 		if err == raftselector.ErrIsLeader {
1841
-			return p.local.Tasks(r, stream)
1849
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1850
+			if err != nil {
1851
+				return err
1852
+			}
1853
+			streamWrapper := Dispatcher_TasksServerWrapper{
1854
+				Dispatcher_TasksServer: stream,
1855
+				ctx: ctx,
1856
+			}
1857
+			return p.local.Tasks(r, streamWrapper)
1842 1858
 		}
1843 1859
 		return err
1844 1860
 	}
1845
-	ctx, err = p.runCtxMods(ctx)
1861
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1846 1862
 	if err != nil {
1847 1863
 		return err
1848 1864
 	}
... ...
@@ -1867,17 +1913,33 @@ func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_Tas
1867 1867
 	return nil
1868 1868
 }
1869 1869
 
1870
-func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error {
1870
+type Dispatcher_AssignmentsServerWrapper struct {
1871
+	Dispatcher_AssignmentsServer
1872
+	ctx context.Context
1873
+}
1874
+
1875
+func (s Dispatcher_AssignmentsServerWrapper) Context() context.Context {
1876
+	return s.ctx
1877
+}
1871 1878
 
1879
+func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error {
1872 1880
 	ctx := stream.Context()
1873 1881
 	conn, err := p.connSelector.LeaderConn(ctx)
1874 1882
 	if err != nil {
1875 1883
 		if err == raftselector.ErrIsLeader {
1876
-			return p.local.Assignments(r, stream)
1884
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1885
+			if err != nil {
1886
+				return err
1887
+			}
1888
+			streamWrapper := Dispatcher_AssignmentsServerWrapper{
1889
+				Dispatcher_AssignmentsServer: stream,
1890
+				ctx: ctx,
1891
+			}
1892
+			return p.local.Assignments(r, streamWrapper)
1877 1893
 		}
1878 1894
 		return err
1879 1895
 	}
1880
-	ctx, err = p.runCtxMods(ctx)
1896
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1881 1897
 	if err != nil {
1882 1898
 		return err
1883 1899
 	}
... ...
@@ -321,12 +321,12 @@ func encodeVarintHealth(data []byte, offset int, v uint64) int {
321 321
 }
322 322
 
323 323
 type raftProxyHealthServer struct {
324
-	local        HealthServer
325
-	connSelector raftselector.ConnProvider
326
-	ctxMods      []func(context.Context) (context.Context, error)
324
+	local                       HealthServer
325
+	connSelector                raftselector.ConnProvider
326
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
327 327
 }
328 328
 
329
-func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) HealthServer {
329
+func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) HealthServer {
330 330
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
331 331
 		s, ok := transport.StreamFromContext(ctx)
332 332
 		if !ok {
... ...
@@ -343,18 +343,24 @@ func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.Conn
343 343
 		md["redirect"] = append(md["redirect"], addr)
344 344
 		return metadata.NewContext(ctx, md), nil
345 345
 	}
346
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
347
-	mods = append(mods, ctxMod)
346
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
347
+	remoteMods = append(remoteMods, remoteCtxMod)
348
+
349
+	var localMods []func(context.Context) (context.Context, error)
350
+	if localCtxMod != nil {
351
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
352
+	}
348 353
 
349 354
 	return &raftProxyHealthServer{
350
-		local:        local,
351
-		connSelector: connSelector,
352
-		ctxMods:      mods,
355
+		local:         local,
356
+		connSelector:  connSelector,
357
+		localCtxMods:  localMods,
358
+		remoteCtxMods: remoteMods,
353 359
 	}
354 360
 }
355
-func (p *raftProxyHealthServer) runCtxMods(ctx context.Context) (context.Context, error) {
361
+func (p *raftProxyHealthServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
356 362
 	var err error
357
-	for _, mod := range p.ctxMods {
363
+	for _, mod := range ctxMods {
358 364
 		ctx, err = mod(ctx)
359 365
 		if err != nil {
360 366
 			return ctx, err
... ...
@@ -391,11 +397,15 @@ func (p *raftProxyHealthServer) Check(ctx context.Context, r *HealthCheckRequest
391 391
 	conn, err := p.connSelector.LeaderConn(ctx)
392 392
 	if err != nil {
393 393
 		if err == raftselector.ErrIsLeader {
394
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
395
+			if err != nil {
396
+				return nil, err
397
+			}
394 398
 			return p.local.Check(ctx, r)
395 399
 		}
396 400
 		return nil, err
397 401
 	}
398
-	modCtx, err := p.runCtxMods(ctx)
402
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
399 403
 	if err != nil {
400 404
 		return nil, err
401 405
 	}
... ...
@@ -1279,12 +1279,12 @@ func encodeVarintLogbroker(data []byte, offset int, v uint64) int {
1279 1279
 }
1280 1280
 
1281 1281
 type raftProxyLogsServer struct {
1282
-	local        LogsServer
1283
-	connSelector raftselector.ConnProvider
1284
-	ctxMods      []func(context.Context) (context.Context, error)
1282
+	local                       LogsServer
1283
+	connSelector                raftselector.ConnProvider
1284
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
1285 1285
 }
1286 1286
 
1287
-func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogsServer {
1287
+func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogsServer {
1288 1288
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
1289 1289
 		s, ok := transport.StreamFromContext(ctx)
1290 1290
 		if !ok {
... ...
@@ -1301,18 +1301,24 @@ func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProv
1301 1301
 		md["redirect"] = append(md["redirect"], addr)
1302 1302
 		return metadata.NewContext(ctx, md), nil
1303 1303
 	}
1304
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
1305
-	mods = append(mods, ctxMod)
1304
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
1305
+	remoteMods = append(remoteMods, remoteCtxMod)
1306
+
1307
+	var localMods []func(context.Context) (context.Context, error)
1308
+	if localCtxMod != nil {
1309
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
1310
+	}
1306 1311
 
1307 1312
 	return &raftProxyLogsServer{
1308
-		local:        local,
1309
-		connSelector: connSelector,
1310
-		ctxMods:      mods,
1313
+		local:         local,
1314
+		connSelector:  connSelector,
1315
+		localCtxMods:  localMods,
1316
+		remoteCtxMods: remoteMods,
1311 1317
 	}
1312 1318
 }
1313
-func (p *raftProxyLogsServer) runCtxMods(ctx context.Context) (context.Context, error) {
1319
+func (p *raftProxyLogsServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
1314 1320
 	var err error
1315
-	for _, mod := range p.ctxMods {
1321
+	for _, mod := range ctxMods {
1316 1322
 		ctx, err = mod(ctx)
1317 1323
 		if err != nil {
1318 1324
 			return ctx, err
... ...
@@ -1344,17 +1350,33 @@ func (p *raftProxyLogsServer) pollNewLeaderConn(ctx context.Context) (*grpc.Clie
1344 1344
 	}
1345 1345
 }
1346 1346
 
1347
-func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error {
1347
+type Logs_SubscribeLogsServerWrapper struct {
1348
+	Logs_SubscribeLogsServer
1349
+	ctx context.Context
1350
+}
1348 1351
 
1352
+func (s Logs_SubscribeLogsServerWrapper) Context() context.Context {
1353
+	return s.ctx
1354
+}
1355
+
1356
+func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error {
1349 1357
 	ctx := stream.Context()
1350 1358
 	conn, err := p.connSelector.LeaderConn(ctx)
1351 1359
 	if err != nil {
1352 1360
 		if err == raftselector.ErrIsLeader {
1353
-			return p.local.SubscribeLogs(r, stream)
1361
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1362
+			if err != nil {
1363
+				return err
1364
+			}
1365
+			streamWrapper := Logs_SubscribeLogsServerWrapper{
1366
+				Logs_SubscribeLogsServer: stream,
1367
+				ctx: ctx,
1368
+			}
1369
+			return p.local.SubscribeLogs(r, streamWrapper)
1354 1370
 		}
1355 1371
 		return err
1356 1372
 	}
1357
-	ctx, err = p.runCtxMods(ctx)
1373
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1358 1374
 	if err != nil {
1359 1375
 		return err
1360 1376
 	}
... ...
@@ -1380,12 +1402,12 @@ func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs
1380 1380
 }
1381 1381
 
1382 1382
 type raftProxyLogBrokerServer struct {
1383
-	local        LogBrokerServer
1384
-	connSelector raftselector.ConnProvider
1385
-	ctxMods      []func(context.Context) (context.Context, error)
1383
+	local                       LogBrokerServer
1384
+	connSelector                raftselector.ConnProvider
1385
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
1386 1386
 }
1387 1387
 
1388
-func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogBrokerServer {
1388
+func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogBrokerServer {
1389 1389
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
1390 1390
 		s, ok := transport.StreamFromContext(ctx)
1391 1391
 		if !ok {
... ...
@@ -1402,18 +1424,24 @@ func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselecto
1402 1402
 		md["redirect"] = append(md["redirect"], addr)
1403 1403
 		return metadata.NewContext(ctx, md), nil
1404 1404
 	}
1405
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
1406
-	mods = append(mods, ctxMod)
1405
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
1406
+	remoteMods = append(remoteMods, remoteCtxMod)
1407
+
1408
+	var localMods []func(context.Context) (context.Context, error)
1409
+	if localCtxMod != nil {
1410
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
1411
+	}
1407 1412
 
1408 1413
 	return &raftProxyLogBrokerServer{
1409
-		local:        local,
1410
-		connSelector: connSelector,
1411
-		ctxMods:      mods,
1414
+		local:         local,
1415
+		connSelector:  connSelector,
1416
+		localCtxMods:  localMods,
1417
+		remoteCtxMods: remoteMods,
1412 1418
 	}
1413 1419
 }
1414
-func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context) (context.Context, error) {
1420
+func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
1415 1421
 	var err error
1416
-	for _, mod := range p.ctxMods {
1422
+	for _, mod := range ctxMods {
1417 1423
 		ctx, err = mod(ctx)
1418 1424
 		if err != nil {
1419 1425
 			return ctx, err
... ...
@@ -1445,17 +1473,33 @@ func (p *raftProxyLogBrokerServer) pollNewLeaderConn(ctx context.Context) (*grpc
1445 1445
 	}
1446 1446
 }
1447 1447
 
1448
-func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error {
1448
+type LogBroker_ListenSubscriptionsServerWrapper struct {
1449
+	LogBroker_ListenSubscriptionsServer
1450
+	ctx context.Context
1451
+}
1449 1452
 
1453
+func (s LogBroker_ListenSubscriptionsServerWrapper) Context() context.Context {
1454
+	return s.ctx
1455
+}
1456
+
1457
+func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error {
1450 1458
 	ctx := stream.Context()
1451 1459
 	conn, err := p.connSelector.LeaderConn(ctx)
1452 1460
 	if err != nil {
1453 1461
 		if err == raftselector.ErrIsLeader {
1454
-			return p.local.ListenSubscriptions(r, stream)
1462
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1463
+			if err != nil {
1464
+				return err
1465
+			}
1466
+			streamWrapper := LogBroker_ListenSubscriptionsServerWrapper{
1467
+				LogBroker_ListenSubscriptionsServer: stream,
1468
+				ctx: ctx,
1469
+			}
1470
+			return p.local.ListenSubscriptions(r, streamWrapper)
1455 1471
 		}
1456 1472
 		return err
1457 1473
 	}
1458
-	ctx, err = p.runCtxMods(ctx)
1474
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1459 1475
 	if err != nil {
1460 1476
 		return err
1461 1477
 	}
... ...
@@ -1480,17 +1524,33 @@ func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsReq
1480 1480
 	return nil
1481 1481
 }
1482 1482
 
1483
-func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error {
1483
+type LogBroker_PublishLogsServerWrapper struct {
1484
+	LogBroker_PublishLogsServer
1485
+	ctx context.Context
1486
+}
1484 1487
 
1488
+func (s LogBroker_PublishLogsServerWrapper) Context() context.Context {
1489
+	return s.ctx
1490
+}
1491
+
1492
+func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error {
1485 1493
 	ctx := stream.Context()
1486 1494
 	conn, err := p.connSelector.LeaderConn(ctx)
1487 1495
 	if err != nil {
1488 1496
 		if err == raftselector.ErrIsLeader {
1489
-			return p.local.PublishLogs(stream)
1497
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1498
+			if err != nil {
1499
+				return err
1500
+			}
1501
+			streamWrapper := LogBroker_PublishLogsServerWrapper{
1502
+				LogBroker_PublishLogsServer: stream,
1503
+				ctx: ctx,
1504
+			}
1505
+			return p.local.PublishLogs(streamWrapper)
1490 1506
 		}
1491 1507
 		return err
1492 1508
 	}
1493
-	ctx, err = p.runCtxMods(ctx)
1509
+	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
1494 1510
 	if err != nil {
1495 1511
 		return err
1496 1512
 	}
... ...
@@ -1498,12 +1498,12 @@ func encodeVarintRaft(data []byte, offset int, v uint64) int {
1498 1498
 }
1499 1499
 
1500 1500
 type raftProxyRaftServer struct {
1501
-	local        RaftServer
1502
-	connSelector raftselector.ConnProvider
1503
-	ctxMods      []func(context.Context) (context.Context, error)
1501
+	local                       RaftServer
1502
+	connSelector                raftselector.ConnProvider
1503
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
1504 1504
 }
1505 1505
 
1506
-func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftServer {
1506
+func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftServer {
1507 1507
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
1508 1508
 		s, ok := transport.StreamFromContext(ctx)
1509 1509
 		if !ok {
... ...
@@ -1520,18 +1520,24 @@ func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProv
1520 1520
 		md["redirect"] = append(md["redirect"], addr)
1521 1521
 		return metadata.NewContext(ctx, md), nil
1522 1522
 	}
1523
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
1524
-	mods = append(mods, ctxMod)
1523
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
1524
+	remoteMods = append(remoteMods, remoteCtxMod)
1525
+
1526
+	var localMods []func(context.Context) (context.Context, error)
1527
+	if localCtxMod != nil {
1528
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
1529
+	}
1525 1530
 
1526 1531
 	return &raftProxyRaftServer{
1527
-		local:        local,
1528
-		connSelector: connSelector,
1529
-		ctxMods:      mods,
1532
+		local:         local,
1533
+		connSelector:  connSelector,
1534
+		localCtxMods:  localMods,
1535
+		remoteCtxMods: remoteMods,
1530 1536
 	}
1531 1537
 }
1532
-func (p *raftProxyRaftServer) runCtxMods(ctx context.Context) (context.Context, error) {
1538
+func (p *raftProxyRaftServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
1533 1539
 	var err error
1534
-	for _, mod := range p.ctxMods {
1540
+	for _, mod := range ctxMods {
1535 1541
 		ctx, err = mod(ctx)
1536 1542
 		if err != nil {
1537 1543
 			return ctx, err
... ...
@@ -1568,11 +1574,15 @@ func (p *raftProxyRaftServer) ProcessRaftMessage(ctx context.Context, r *Process
1568 1568
 	conn, err := p.connSelector.LeaderConn(ctx)
1569 1569
 	if err != nil {
1570 1570
 		if err == raftselector.ErrIsLeader {
1571
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1572
+			if err != nil {
1573
+				return nil, err
1574
+			}
1571 1575
 			return p.local.ProcessRaftMessage(ctx, r)
1572 1576
 		}
1573 1577
 		return nil, err
1574 1578
 	}
1575
-	modCtx, err := p.runCtxMods(ctx)
1579
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1576 1580
 	if err != nil {
1577 1581
 		return nil, err
1578 1582
 	}
... ...
@@ -1599,11 +1609,15 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr
1599 1599
 	conn, err := p.connSelector.LeaderConn(ctx)
1600 1600
 	if err != nil {
1601 1601
 		if err == raftselector.ErrIsLeader {
1602
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1603
+			if err != nil {
1604
+				return nil, err
1605
+			}
1602 1606
 			return p.local.ResolveAddress(ctx, r)
1603 1607
 		}
1604 1608
 		return nil, err
1605 1609
 	}
1606
-	modCtx, err := p.runCtxMods(ctx)
1610
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1607 1611
 	if err != nil {
1608 1612
 		return nil, err
1609 1613
 	}
... ...
@@ -1626,12 +1640,12 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr
1626 1626
 }
1627 1627
 
1628 1628
 type raftProxyRaftMembershipServer struct {
1629
-	local        RaftMembershipServer
1630
-	connSelector raftselector.ConnProvider
1631
-	ctxMods      []func(context.Context) (context.Context, error)
1629
+	local                       RaftMembershipServer
1630
+	connSelector                raftselector.ConnProvider
1631
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
1632 1632
 }
1633 1633
 
1634
-func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftMembershipServer {
1634
+func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftMembershipServer {
1635 1635
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
1636 1636
 		s, ok := transport.StreamFromContext(ctx)
1637 1637
 		if !ok {
... ...
@@ -1648,18 +1662,24 @@ func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector r
1648 1648
 		md["redirect"] = append(md["redirect"], addr)
1649 1649
 		return metadata.NewContext(ctx, md), nil
1650 1650
 	}
1651
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
1652
-	mods = append(mods, ctxMod)
1651
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
1652
+	remoteMods = append(remoteMods, remoteCtxMod)
1653
+
1654
+	var localMods []func(context.Context) (context.Context, error)
1655
+	if localCtxMod != nil {
1656
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
1657
+	}
1653 1658
 
1654 1659
 	return &raftProxyRaftMembershipServer{
1655
-		local:        local,
1656
-		connSelector: connSelector,
1657
-		ctxMods:      mods,
1660
+		local:         local,
1661
+		connSelector:  connSelector,
1662
+		localCtxMods:  localMods,
1663
+		remoteCtxMods: remoteMods,
1658 1664
 	}
1659 1665
 }
1660
-func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context) (context.Context, error) {
1666
+func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
1661 1667
 	var err error
1662
-	for _, mod := range p.ctxMods {
1668
+	for _, mod := range ctxMods {
1663 1669
 		ctx, err = mod(ctx)
1664 1670
 		if err != nil {
1665 1671
 			return ctx, err
... ...
@@ -1696,11 +1716,15 @@ func (p *raftProxyRaftMembershipServer) Join(ctx context.Context, r *JoinRequest
1696 1696
 	conn, err := p.connSelector.LeaderConn(ctx)
1697 1697
 	if err != nil {
1698 1698
 		if err == raftselector.ErrIsLeader {
1699
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1700
+			if err != nil {
1701
+				return nil, err
1702
+			}
1699 1703
 			return p.local.Join(ctx, r)
1700 1704
 		}
1701 1705
 		return nil, err
1702 1706
 	}
1703
-	modCtx, err := p.runCtxMods(ctx)
1707
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1704 1708
 	if err != nil {
1705 1709
 		return nil, err
1706 1710
 	}
... ...
@@ -1727,11 +1751,15 @@ func (p *raftProxyRaftMembershipServer) Leave(ctx context.Context, r *LeaveReque
1727 1727
 	conn, err := p.connSelector.LeaderConn(ctx)
1728 1728
 	if err != nil {
1729 1729
 		if err == raftselector.ErrIsLeader {
1730
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
1731
+			if err != nil {
1732
+				return nil, err
1733
+			}
1730 1734
 			return p.local.Leave(ctx, r)
1731 1735
 		}
1732 1736
 		return nil, err
1733 1737
 	}
1734
-	modCtx, err := p.runCtxMods(ctx)
1738
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
1735 1739
 	if err != nil {
1736 1740
 		return nil, err
1737 1741
 	}
... ...
@@ -451,12 +451,12 @@ func encodeVarintResource(data []byte, offset int, v uint64) int {
451 451
 }
452 452
 
453 453
 type raftProxyResourceAllocatorServer struct {
454
-	local        ResourceAllocatorServer
455
-	connSelector raftselector.ConnProvider
456
-	ctxMods      []func(context.Context) (context.Context, error)
454
+	local                       ResourceAllocatorServer
455
+	connSelector                raftselector.ConnProvider
456
+	localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error)
457 457
 }
458 458
 
459
-func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer {
459
+func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer {
460 460
 	redirectChecker := func(ctx context.Context) (context.Context, error) {
461 461
 		s, ok := transport.StreamFromContext(ctx)
462 462
 		if !ok {
... ...
@@ -473,18 +473,24 @@ func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSele
473 473
 		md["redirect"] = append(md["redirect"], addr)
474 474
 		return metadata.NewContext(ctx, md), nil
475 475
 	}
476
-	mods := []func(context.Context) (context.Context, error){redirectChecker}
477
-	mods = append(mods, ctxMod)
476
+	remoteMods := []func(context.Context) (context.Context, error){redirectChecker}
477
+	remoteMods = append(remoteMods, remoteCtxMod)
478
+
479
+	var localMods []func(context.Context) (context.Context, error)
480
+	if localCtxMod != nil {
481
+		localMods = []func(context.Context) (context.Context, error){localCtxMod}
482
+	}
478 483
 
479 484
 	return &raftProxyResourceAllocatorServer{
480
-		local:        local,
481
-		connSelector: connSelector,
482
-		ctxMods:      mods,
485
+		local:         local,
486
+		connSelector:  connSelector,
487
+		localCtxMods:  localMods,
488
+		remoteCtxMods: remoteMods,
483 489
 	}
484 490
 }
485
-func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context) (context.Context, error) {
491
+func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) {
486 492
 	var err error
487
-	for _, mod := range p.ctxMods {
493
+	for _, mod := range ctxMods {
488 494
 		ctx, err = mod(ctx)
489 495
 		if err != nil {
490 496
 			return ctx, err
... ...
@@ -521,11 +527,15 @@ func (p *raftProxyResourceAllocatorServer) AttachNetwork(ctx context.Context, r
521 521
 	conn, err := p.connSelector.LeaderConn(ctx)
522 522
 	if err != nil {
523 523
 		if err == raftselector.ErrIsLeader {
524
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
525
+			if err != nil {
526
+				return nil, err
527
+			}
524 528
 			return p.local.AttachNetwork(ctx, r)
525 529
 		}
526 530
 		return nil, err
527 531
 	}
528
-	modCtx, err := p.runCtxMods(ctx)
532
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
529 533
 	if err != nil {
530 534
 		return nil, err
531 535
 	}
... ...
@@ -552,11 +562,15 @@ func (p *raftProxyResourceAllocatorServer) DetachNetwork(ctx context.Context, r
552 552
 	conn, err := p.connSelector.LeaderConn(ctx)
553 553
 	if err != nil {
554 554
 		if err == raftselector.ErrIsLeader {
555
+			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
556
+			if err != nil {
557
+				return nil, err
558
+			}
555 559
 			return p.local.DetachNetwork(ctx, r)
556 560
 		}
557 561
 		return nil, err
558 562
 	}
559
-	modCtx, err := p.runCtxMods(ctx)
563
+	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
560 564
 	if err != nil {
561 565
 		return nil, err
562 566
 	}
... ...
@@ -16,6 +16,13 @@ import (
16 16
 	"google.golang.org/grpc/peer"
17 17
 )
18 18
 
19
+type localRequestKeyType struct{}
20
+
21
+// LocalRequestKey is a context key to mark a request that originating on the
22
+// local node. The assocated value is a RemoteNodeInfo structure describing the
23
+// local node.
24
+var LocalRequestKey = localRequestKeyType{}
25
+
19 26
 // LogTLSState logs information about the TLS connection and remote peers
20 27
 func LogTLSState(ctx context.Context, tlsState *tls.ConnectionState) {
21 28
 	if tlsState == nil {
... ...
@@ -189,6 +196,17 @@ type RemoteNodeInfo struct {
189 189
 // well as the forwarder's ID. This function does not do authorization checks -
190 190
 // it only looks up the node ID.
191 191
 func RemoteNode(ctx context.Context) (RemoteNodeInfo, error) {
192
+	// If we have a value on the context that marks this as a local
193
+	// request, we return the node info from the context.
194
+	localNodeInfo := ctx.Value(LocalRequestKey)
195
+
196
+	if localNodeInfo != nil {
197
+		nodeInfo, ok := localNodeInfo.(RemoteNodeInfo)
198
+		if ok {
199
+			return nodeInfo, nil
200
+		}
201
+	}
202
+
192 203
 	certSubj, err := certSubjectFromContext(ctx)
193 204
 	if err != nil {
194 205
 		return RemoteNodeInfo{}, err
... ...
@@ -28,7 +28,6 @@ const (
28 28
 // breaking it apart doesn't seem worth it.
29 29
 type Server struct {
30 30
 	mu                          sync.Mutex
31
-	wg                          sync.WaitGroup
32 31
 	ctx                         context.Context
33 32
 	cancel                      func()
34 33
 	store                       *store.MemoryStore
... ...
@@ -102,10 +101,9 @@ func (s *Server) NodeCertificateStatus(ctx context.Context, request *api.NodeCer
102 102
 		return nil, grpc.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
103 103
 	}
104 104
 
105
-	if err := s.addTask(); err != nil {
105
+	if err := s.isRunningLocked(); err != nil {
106 106
 		return nil, err
107 107
 	}
108
-	defer s.doneTask()
109 108
 
110 109
 	var node *api.Node
111 110
 
... ...
@@ -189,10 +187,9 @@ func (s *Server) IssueNodeCertificate(ctx context.Context, request *api.IssueNod
189 189
 		return nil, grpc.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
190 190
 	}
191 191
 
192
-	if err := s.addTask(); err != nil {
192
+	if err := s.isRunningLocked(); err != nil {
193 193
 		return nil, err
194 194
 	}
195
-	defer s.doneTask()
196 195
 
197 196
 	var (
198 197
 		blacklistedCerts map[string]*api.BlacklistedCertificate
... ...
@@ -211,6 +208,15 @@ func (s *Server) IssueNodeCertificate(ctx context.Context, request *api.IssueNod
211 211
 		blacklistedCerts = clusters[0].BlacklistedCertificates
212 212
 	}
213 213
 
214
+	// Renewing the cert with a local (unix socket) is always valid.
215
+	localNodeInfo := ctx.Value(LocalRequestKey)
216
+	if localNodeInfo != nil {
217
+		nodeInfo, ok := localNodeInfo.(RemoteNodeInfo)
218
+		if ok && nodeInfo.NodeID != "" {
219
+			return s.issueRenewCertificate(ctx, nodeInfo.NodeID, request.CSR)
220
+		}
221
+	}
222
+
214 223
 	// If the remote node is a worker (either forwarded by a manager, or calling directly),
215 224
 	// issue a renew worker certificate entry with the correct ID
216 225
 	nodeID, err := AuthorizeForwardedRoleAndOrg(ctx, []string{WorkerRole}, []string{ManagerRole}, s.securityConfig.ClientTLSCreds.Organization(), blacklistedCerts)
... ...
@@ -365,10 +371,8 @@ func (s *Server) Run(ctx context.Context) error {
365 365
 		s.mu.Unlock()
366 366
 		return errors.New("CA signer is already running")
367 367
 	}
368
-	s.wg.Add(1)
369 368
 	s.mu.Unlock()
370 369
 
371
-	defer s.wg.Done()
372 370
 	ctx = log.WithModule(ctx, "ca")
373 371
 
374 372
 	// Retrieve the channels to keep track of changes in the cluster
... ...
@@ -398,8 +402,8 @@ func (s *Server) Run(ctx context.Context) error {
398 398
 	// returns true without joinTokens being set correctly.
399 399
 	s.mu.Lock()
400 400
 	s.ctx, s.cancel = context.WithCancel(ctx)
401
-	s.mu.Unlock()
402 401
 	close(s.started)
402
+	s.mu.Unlock()
403 403
 
404 404
 	if err != nil {
405 405
 		log.G(ctx).WithFields(logrus.Fields{
... ...
@@ -460,38 +464,32 @@ func (s *Server) Run(ctx context.Context) error {
460 460
 // Stop stops the CA and closes all grpc streams.
461 461
 func (s *Server) Stop() error {
462 462
 	s.mu.Lock()
463
+	defer s.mu.Unlock()
463 464
 	if !s.isRunning() {
464
-		s.mu.Unlock()
465 465
 		return errors.New("CA signer is already stopped")
466 466
 	}
467 467
 	s.cancel()
468
-	s.mu.Unlock()
469
-	// wait for all handlers to finish their CA deals,
470
-	s.wg.Wait()
471 468
 	s.started = make(chan struct{})
472 469
 	return nil
473 470
 }
474 471
 
475 472
 // Ready waits on the ready channel and returns when the server is ready to serve.
476 473
 func (s *Server) Ready() <-chan struct{} {
474
+	s.mu.Lock()
475
+	defer s.mu.Unlock()
477 476
 	return s.started
478 477
 }
479 478
 
480
-func (s *Server) addTask() error {
479
+func (s *Server) isRunningLocked() error {
481 480
 	s.mu.Lock()
482 481
 	if !s.isRunning() {
483 482
 		s.mu.Unlock()
484 483
 		return grpc.Errorf(codes.Aborted, "CA signer is stopped")
485 484
 	}
486
-	s.wg.Add(1)
487 485
 	s.mu.Unlock()
488 486
 	return nil
489 487
 }
490 488
 
491
-func (s *Server) doneTask() {
492
-	s.wg.Done()
493
-}
494
-
495 489
 func (s *Server) isRunning() bool {
496 490
 	if s.ctx == nil {
497 491
 		return false
... ...
@@ -38,6 +38,71 @@ type portSpace struct {
38 38
 	dynamicPortSpace *idm.Idm
39 39
 }
40 40
 
41
+type allocatedPorts map[api.PortConfig]map[uint32]*api.PortConfig
42
+
43
+// addState add the state of an allocated port to the collection.
44
+// `allocatedPorts` is a map of portKey:publishedPort:portState.
45
+// In case the value of the portKey is missing, the map
46
+// publishedPort:portState is created automatically
47
+func (ps allocatedPorts) addState(p *api.PortConfig) {
48
+	portKey := getPortConfigKey(p)
49
+	if _, ok := ps[portKey]; !ok {
50
+		ps[portKey] = make(map[uint32]*api.PortConfig)
51
+	}
52
+	ps[portKey][p.PublishedPort] = p
53
+}
54
+
55
+// delState delete the state of an allocated port from the collection.
56
+// `allocatedPorts` is a map of portKey:publishedPort:portState.
57
+//
58
+// If publishedPort is non-zero, then it is user defined. We will try to
59
+// remove the portState from `allocatedPorts` directly and return
60
+// the portState (or nil if no portState exists)
61
+//
62
+// If publishedPort is zero, then it is dynamically allocated. We will try
63
+// to remove the portState from `allocatedPorts`, as long as there is
64
+// a portState associated with a non-zero publishedPort.
65
+// Note multiple dynamically allocated ports might exists. In this case,
66
+// we will remove only at a time so both allocated ports are tracked.
67
+//
68
+// Note becasue of the potential co-existence of user-defined and dynamically
69
+// allocated ports, delState has to be called for user-defined port first.
70
+// dynamically allocated ports should be removed later.
71
+func (ps allocatedPorts) delState(p *api.PortConfig) *api.PortConfig {
72
+	portKey := getPortConfigKey(p)
73
+
74
+	portStateMap, ok := ps[portKey]
75
+
76
+	// If name, port, protocol values don't match then we
77
+	// are not allocated.
78
+	if !ok {
79
+		return nil
80
+	}
81
+
82
+	if p.PublishedPort != 0 {
83
+		// If SwarmPort was user defined but the port state
84
+		// SwarmPort doesn't match we are not allocated.
85
+		v := portStateMap[p.PublishedPort]
86
+
87
+		// Delete state from allocatedPorts
88
+		delete(portStateMap, p.PublishedPort)
89
+
90
+		return v
91
+	}
92
+
93
+	// If PublishedPort == 0 and we don't have non-zero port
94
+	// then we are not allocated
95
+	for publishedPort, v := range portStateMap {
96
+		if publishedPort != 0 {
97
+			// Delete state from allocatedPorts
98
+			delete(portStateMap, publishedPort)
99
+			return v
100
+		}
101
+	}
102
+
103
+	return nil
104
+}
105
+
41 106
 func newPortAllocator() (*portAllocator, error) {
42 107
 	portSpaces := make(map[api.PortConfig_Protocol]*portSpace)
43 108
 	for _, protocol := range []api.PortConfig_Protocol{api.ProtocolTCP, api.ProtocolUDP} {
... ...
@@ -91,40 +156,53 @@ func reconcilePortConfigs(s *api.Service) []*api.PortConfig {
91 91
 		return s.Spec.Endpoint.Ports
92 92
 	}
93 93
 
94
-	allocatedPorts := make(map[api.PortConfig]*api.PortConfig)
94
+	portStates := allocatedPorts{}
95 95
 	for _, portState := range s.Endpoint.Ports {
96
-		if portState.PublishMode != api.PublishModeIngress {
97
-			continue
96
+		if portState.PublishMode == api.PublishModeIngress {
97
+			portStates.addState(portState)
98 98
 		}
99
-
100
-		allocatedPorts[getPortConfigKey(portState)] = portState
101 99
 	}
102 100
 
103 101
 	var portConfigs []*api.PortConfig
102
+
103
+	// Process the portConfig with portConfig.PublishMode != api.PublishModeIngress
104
+	// and PublishedPort != 0 (high priority)
104 105
 	for _, portConfig := range s.Spec.Endpoint.Ports {
105
-		// If the PublishMode is not Ingress simply pick up
106
-		// the port config.
107 106
 		if portConfig.PublishMode != api.PublishModeIngress {
107
+			// If the PublishMode is not Ingress simply pick up the port config.
108 108
 			portConfigs = append(portConfigs, portConfig)
109
-			continue
110
-		}
109
+		} else if portConfig.PublishedPort != 0 {
110
+			// Otherwise we only process PublishedPort != 0 in this round
111 111
 
112
-		portState, ok := allocatedPorts[getPortConfigKey(portConfig)]
113
-
114
-		// If the portConfig is exactly the same as portState
115
-		// except if SwarmPort is not user-define then prefer
116
-		// portState to ensure sticky allocation of the same
117
-		// port that was allocated before.
118
-		if ok && portConfig.Name == portState.Name &&
119
-			portConfig.TargetPort == portState.TargetPort &&
120
-			portConfig.Protocol == portState.Protocol &&
121
-			portConfig.PublishedPort == 0 {
122
-			portConfigs = append(portConfigs, portState)
123
-			continue
112
+			// Remove record from portState
113
+			portStates.delState(portConfig)
114
+
115
+			// For PublishedPort != 0 prefer the portConfig
116
+			portConfigs = append(portConfigs, portConfig)
124 117
 		}
118
+	}
119
+
120
+	// Iterate portConfigs with PublishedPort == 0 (low priority)
121
+	for _, portConfig := range s.Spec.Endpoint.Ports {
122
+		// Ignore ports which are not PublishModeIngress (already processed)
123
+		// And we only process PublishedPort == 0 in this round
124
+		// So the following:
125
+		//  `portConfig.PublishMode == api.PublishModeIngress && portConfig.PublishedPort == 0`
126
+		if portConfig.PublishMode == api.PublishModeIngress && portConfig.PublishedPort == 0 {
127
+			// If the portConfig is exactly the same as portState
128
+			// except if SwarmPort is not user-define then prefer
129
+			// portState to ensure sticky allocation of the same
130
+			// port that was allocated before.
131
+
132
+			// Remove record from portState
133
+			if portState := portStates.delState(portConfig); portState != nil {
134
+				portConfigs = append(portConfigs, portState)
135
+				continue
136
+			}
125 137
 
126
-		// For all other cases prefer the portConfig
127
-		portConfigs = append(portConfigs, portConfig)
138
+			// For all other cases prefer the portConfig
139
+			portConfigs = append(portConfigs, portConfig)
140
+		}
128 141
 	}
129 142
 
130 143
 	return portConfigs
... ...
@@ -213,40 +291,31 @@ func (pa *portAllocator) isPortsAllocated(s *api.Service) bool {
213 213
 		return false
214 214
 	}
215 215
 
216
-	allocatedPorts := make(map[api.PortConfig]*api.PortConfig)
216
+	portStates := allocatedPorts{}
217 217
 	for _, portState := range s.Endpoint.Ports {
218
-		if portState.PublishMode != api.PublishModeIngress {
219
-			continue
218
+		if portState.PublishMode == api.PublishModeIngress {
219
+			portStates.addState(portState)
220 220
 		}
221
-
222
-		allocatedPorts[getPortConfigKey(portState)] = portState
223 221
 	}
224 222
 
223
+	// Iterate portConfigs with PublishedPort != 0 (high priority)
225 224
 	for _, portConfig := range s.Spec.Endpoint.Ports {
226 225
 		// Ignore ports which are not PublishModeIngress
227 226
 		if portConfig.PublishMode != api.PublishModeIngress {
228 227
 			continue
229 228
 		}
230
-
231
-		portState, ok := allocatedPorts[getPortConfigKey(portConfig)]
232
-
233
-		// If name, port, protocol values don't match then we
234
-		// are not allocated.
235
-		if !ok {
229
+		if portConfig.PublishedPort != 0 && portStates.delState(portConfig) == nil {
236 230
 			return false
237 231
 		}
232
+	}
238 233
 
239
-		// If SwarmPort was user defined but the port state
240
-		// SwarmPort doesn't match we are not allocated.
241
-		if portConfig.PublishedPort != portState.PublishedPort &&
242
-			portConfig.PublishedPort != 0 {
243
-			return false
234
+	// Iterate portConfigs with PublishedPort == 0 (low priority)
235
+	for _, portConfig := range s.Spec.Endpoint.Ports {
236
+		// Ignore ports which are not PublishModeIngress
237
+		if portConfig.PublishMode != api.PublishModeIngress {
238
+			continue
244 239
 		}
245
-
246
-		// If SwarmPort was not defined by user and port state
247
-		// is not initialized with a valid SwarmPort value then
248
-		// we are not allocated.
249
-		if portConfig.PublishedPort == 0 && portState.PublishedPort == 0 {
240
+		if portConfig.PublishedPort == 0 && portStates.delState(portConfig) == nil {
250 241
 			return false
251 242
 		}
252 243
 	}
... ...
@@ -133,7 +133,6 @@ type Dispatcher struct {
133 133
 }
134 134
 
135 135
 // New returns Dispatcher with cluster interface(usually raft.Node).
136
-// NOTE: each handler which does something with raft must add to Dispatcher.wg
137 136
 func New(cluster Cluster, c *Config) *Dispatcher {
138 137
 	d := &Dispatcher{
139 138
 		nodes:                 newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod),
... ...
@@ -335,23 +335,46 @@ func (m *Manager) Run(parent context.Context) error {
335 335
 	authenticatedHealthAPI := api.NewAuthenticatedWrapperHealthServer(healthServer, authorize)
336 336
 	authenticatedRaftMembershipAPI := api.NewAuthenticatedWrapperRaftMembershipServer(m.raftNode, authorize)
337 337
 
338
-	proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
339
-	proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
340
-	proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
341
-	proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
342
-	proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
343
-	proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, ca.WithMetadataForwardTLSInfo)
344
-
345
-	// localProxyControlAPI is a special kind of proxy. It is only wired up
346
-	// to receive requests from a trusted local socket, and these requests
347
-	// don't use TLS, therefore the requests it handles locally should
348
-	// bypass authorization. When it proxies, it sends them as requests from
349
-	// this manager rather than forwarded requests (it has no TLS
350
-	// information to put in the metadata map).
338
+	proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
339
+	proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
340
+	proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
341
+	proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
342
+	proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
343
+	proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo)
344
+
345
+	// The following local proxies are only wired up to receive requests
346
+	// from a trusted local socket, and these requests don't use TLS,
347
+	// therefore the requests they handle locally should bypass
348
+	// authorization. When requests are proxied from these servers, they
349
+	// are sent as requests from this manager rather than forwarded
350
+	// requests (it has no TLS information to put in the metadata map).
351 351
 	forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil }
352
-	localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, forwardAsOwnRequest)
353
-	localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, forwardAsOwnRequest)
354
-	localCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, forwardAsOwnRequest)
352
+	handleRequestLocally := func(ctx context.Context) (context.Context, error) {
353
+		var remoteAddr string
354
+		if m.config.RemoteAPI.AdvertiseAddr != "" {
355
+			remoteAddr = m.config.RemoteAPI.AdvertiseAddr
356
+		} else {
357
+			remoteAddr = m.config.RemoteAPI.ListenAddr
358
+		}
359
+
360
+		creds := m.config.SecurityConfig.ClientTLSCreds
361
+
362
+		nodeInfo := ca.RemoteNodeInfo{
363
+			Roles:        []string{creds.Role()},
364
+			Organization: creds.Organization(),
365
+			NodeID:       creds.NodeID(),
366
+			RemoteAddr:   remoteAddr,
367
+		}
368
+
369
+		return context.WithValue(ctx, ca.LocalRequestKey, nodeInfo), nil
370
+	}
371
+	localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
372
+	localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
373
+	localProxyDispatcherAPI := api.NewRaftProxyDispatcherServer(m.dispatcher, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
374
+	localProxyCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
375
+	localProxyNodeCAAPI := api.NewRaftProxyNodeCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
376
+	localProxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(baseResourceAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
377
+	localProxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest)
355 378
 
356 379
 	// Everything registered on m.server should be an authenticated
357 380
 	// wrapper, or a proxy wrapping an authenticated wrapper!
... ...
@@ -369,7 +392,11 @@ func (m *Manager) Run(parent context.Context) error {
369 369
 	api.RegisterControlServer(m.localserver, localProxyControlAPI)
370 370
 	api.RegisterLogsServer(m.localserver, localProxyLogsAPI)
371 371
 	api.RegisterHealthServer(m.localserver, localHealthServer)
372
-	api.RegisterCAServer(m.localserver, localCAAPI)
372
+	api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI)
373
+	api.RegisterCAServer(m.localserver, localProxyCAAPI)
374
+	api.RegisterNodeCAServer(m.localserver, localProxyNodeCAAPI)
375
+	api.RegisterResourceAllocatorServer(m.localserver, localProxyResourceAPI)
376
+	api.RegisterLogBrokerServer(m.localserver, localProxyLogBrokerAPI)
373 377
 
374 378
 	healthServer.SetServingStatus("Raft", api.HealthCheckResponse_NOT_SERVING)
375 379
 	localHealthServer.SetServingStatus("ControlAPI", api.HealthCheckResponse_NOT_SERVING)