Browse code

Overlay driver to support network layer encryption

Signed-off-by: Alessandro Boch <aboch@docker.com>

Alessandro Boch authored on 2016/06/07 10:17:10
Showing 7 changed files
... ...
@@ -3,10 +3,12 @@ package libnetwork
3 3
 //go:generate protoc -I.:Godeps/_workspace/src/github.com/gogo/protobuf  --gogo_out=import_path=github.com/docker/libnetwork,Mgogoproto/gogo.proto=github.com/gogo/protobuf/gogoproto:. agent.proto
4 4
 
5 5
 import (
6
+	"encoding/hex"
6 7
 	"fmt"
7 8
 	"net"
8 9
 	"os"
9 10
 	"sort"
11
+	"strconv"
10 12
 
11 13
 	"github.com/Sirupsen/logrus"
12 14
 	"github.com/docker/go-events"
... ...
@@ -72,6 +74,8 @@ func resolveAddr(addrOrInterface string) (string, error) {
72 72
 }
73 73
 
74 74
 func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
75
+	drvEnc := discoverapi.DriverEncryptionUpdate{}
76
+
75 77
 	// Find the new key and add it to the key ring
76 78
 	a := c.agent
77 79
 	for _, key := range keys {
... ...
@@ -86,6 +90,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
86 86
 			if key.Subsystem == "networking:gossip" {
87 87
 				a.networkDB.SetKey(key.Key)
88 88
 			}
89
+			if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
90
+				drvEnc.Key = hex.EncodeToString(key.Key)
91
+				drvEnc.Tag = strconv.FormatUint(key.LamportTime, 10)
92
+			}
89 93
 			break
90 94
 		}
91 95
 	}
... ...
@@ -103,6 +111,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
103 103
 			if cKey.Subsystem == "networking:gossip" {
104 104
 				deleted = cKey.Key
105 105
 			}
106
+			if cKey.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
107
+				drvEnc.Prune = hex.EncodeToString(cKey.Key)
108
+				drvEnc.PruneTag = strconv.FormatUint(cKey.LamportTime, 10)
109
+			}
106 110
 			c.keys = append(c.keys[:i], c.keys[i+1:]...)
107 111
 			break
108 112
 		}
... ...
@@ -115,9 +127,25 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
115 115
 			break
116 116
 		}
117 117
 	}
118
+	for _, key := range c.keys {
119
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
120
+			drvEnc.Primary = hex.EncodeToString(key.Key)
121
+			drvEnc.PrimaryTag = strconv.FormatUint(key.LamportTime, 10)
122
+			break
123
+		}
124
+	}
118 125
 	if len(deleted) > 0 {
119 126
 		a.networkDB.RemoveKey(deleted)
120 127
 	}
128
+
129
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
130
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysUpdate, drvEnc)
131
+		if err != nil {
132
+			logrus.Warnf("Failed to update datapath keys in driver %s: %v", name, err)
133
+		}
134
+		return false
135
+	})
136
+
121 137
 	return nil
122 138
 }
123 139
 
... ...
@@ -170,6 +198,8 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
170 170
 		return nil
171 171
 	}
172 172
 
173
+	drvEnc := discoverapi.DriverEncryptionConfig{}
174
+
173 175
 	// sort the keys by lamport time
174 176
 	sort.Sort(ByTime(c.keys))
175 177
 
... ...
@@ -178,6 +208,10 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
178 178
 		if key.Subsystem == "networking:gossip" {
179 179
 			gossipkey = append(gossipkey, key.Key)
180 180
 		}
181
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
182
+			drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key))
183
+			drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10))
184
+		}
181 185
 	}
182 186
 
183 187
 	bindAddr, err := resolveAddr(bindAddrOrInterface)
... ...
@@ -206,6 +240,15 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
206 206
 	}
207 207
 
208 208
 	go c.handleTableEvents(ch, c.handleEpTableEvent)
209
+
210
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
211
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc)
212
+		if err != nil {
213
+			logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err)
214
+		}
215
+		return false
216
+	})
217
+
209 218
 	return nil
210 219
 }
211 220
 
... ...
@@ -226,6 +269,22 @@ func (c *controller) agentDriverNotify(d driverapi.Driver) {
226 226
 		Address: c.agent.bindAddr,
227 227
 		Self:    true,
228 228
 	})
229
+
230
+	drvEnc := discoverapi.DriverEncryptionConfig{}
231
+	for _, key := range c.keys {
232
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
233
+			drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key))
234
+			drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10))
235
+		}
236
+	}
237
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
238
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc)
239
+		if err != nil {
240
+			logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err)
241
+		}
242
+		return false
243
+	})
244
+
229 245
 }
230 246
 
231 247
 func (c *controller) agentClose() {
... ...
@@ -18,6 +18,10 @@ const (
18 18
 	NodeDiscovery = iota + 1
19 19
 	// DatastoreConfig represents an add/remove datastore event
20 20
 	DatastoreConfig
21
+	// EncryptionKeysConfig represents the initial key(s) for performing datapath encryption
22
+	EncryptionKeysConfig
23
+	// EncryptionKeysUpdate represents an update to the datapath encryption key(s)
24
+	EncryptionKeysUpdate
21 25
 )
22 26
 
23 27
 // NodeDiscoveryData represents the structure backing the node discovery data json string
... ...
@@ -33,3 +37,23 @@ type DatastoreConfigData struct {
33 33
 	Address  string
34 34
 	Config   interface{}
35 35
 }
36
+
37
+// DriverEncryptionConfig contains the initial datapath encryption key(s)
38
+// Key in first position is the primary key, the one to be used in tx.
39
+// Original key and tag types are []byte and uint64
40
+type DriverEncryptionConfig struct {
41
+	Keys []string
42
+	Tags []string
43
+}
44
+
45
+// DriverEncryptionUpdate carries an update to the encryption key(s): any
46
+// combination of a new key, a new primary key and a key to be removed.
47
+// Original key and tag types are []byte and uint64
48
+type DriverEncryptionUpdate struct {
49
+	Key        string
50
+	Tag        string
51
+	Primary    string
52
+	PrimaryTag string
53
+	Prune      string
54
+	PruneTag   string
55
+}
36 56
new file mode 100644
... ...
@@ -0,0 +1,578 @@
0
+package overlay
1
+
2
+import (
3
+	"bytes"
4
+	"encoding/hex"
5
+	"fmt"
6
+	"net"
7
+	"sync"
8
+	"syscall"
9
+
10
+	log "github.com/Sirupsen/logrus"
11
+	"github.com/docker/libnetwork/iptables"
12
+	"github.com/docker/libnetwork/types"
13
+	"github.com/vishvananda/netlink"
14
+	"strconv"
15
+)
16
+
17
// Datapath encryption constants.
const (
	// mark is the packet mark set by the mangle rule and matched by the
	// xfrm policy programmed in this file.
	mark uint32 = 0xD0C4E3
	// timeout is the value written into the xfrm state time limits
	// (TimeSoft / TimeHard) while a key is being rotated out.
	timeout = 30
)
21
+
22
// Direction flags telling programSA which security associations to handle.
// They are tested as a bitmask, so bidir selects both directions.
const (
	forward = 1 << iota // local -> remote
	reverse             // remote -> local
	bidir   = forward | reverse
)
27
+
28
// key is a datapath encryption key together with the numeric tag that
// identifies it.
type key struct {
	value []byte
	tag   uint32
}

// String returns a loggable description of the key: at most the first five
// hex digits of the key material, followed by the tag in hexadecimal.
// A nil key yields "<nil>", and key material shorter than three bytes is
// printed in full instead of causing an out-of-range slice panic.
func (k *key) String() string {
	if k == nil {
		return "<nil>"
	}
	digits := hex.EncodeToString(k.value)
	if len(digits) > 5 {
		digits = digits[:5]
	}
	return fmt.Sprintf("(key: %s, tag: 0x%x)", digits, k.tag)
}
36
+
37
// spi holds the two values used as the Spi of the forward (local to remote)
// and reverse (remote to local) xfrm states for one key and one peer.
type spi struct {
	forward int
	reverse int
}

// String renders both indexes as unsigned 32-bit hexadecimal numbers.
func (s *spi) String() string {
	fwd, rev := uint32(s.forward), uint32(s.reverse)
	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", fwd, rev)
}
45
+
46
+type encrMap struct {
47
+	nodes map[string][]*spi
48
+	sync.Mutex
49
+}
50
+
51
+func (e *encrMap) String() string {
52
+	e.Lock()
53
+	defer e.Unlock()
54
+	b := new(bytes.Buffer)
55
+	for k, v := range e.nodes {
56
+		b.WriteString("\n")
57
+		b.WriteString(k)
58
+		b.WriteString(":")
59
+		b.WriteString("[")
60
+		for _, s := range v {
61
+			b.WriteString(s.String())
62
+			b.WriteString(",")
63
+		}
64
+		b.WriteString("]")
65
+
66
+	}
67
+	return b.String()
68
+}
69
+
70
+func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
71
+	log.Infof("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
72
+
73
+	n := d.network(nid)
74
+	if n == nil || !n.secure {
75
+		return nil
76
+	}
77
+
78
+	if len(d.keys) == 0 {
79
+		return types.ForbiddenErrorf("encryption key is not present")
80
+	}
81
+
82
+	lIP := types.GetMinimalIP(net.ParseIP(d.bindAddress))
83
+	nodes := map[string]net.IP{}
84
+
85
+	switch {
86
+	case isLocal:
87
+		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
88
+			if !lIP.Equal(pEntry.vtep) {
89
+				nodes[pEntry.vtep.String()] = types.GetMinimalIP(pEntry.vtep)
90
+			}
91
+			return false
92
+		}); err != nil {
93
+			log.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
94
+		}
95
+	default:
96
+		if len(d.network(nid).endpoints) > 0 {
97
+			nodes[rIP.String()] = types.GetMinimalIP(rIP)
98
+		}
99
+	}
100
+
101
+	log.Debugf("List of nodes: %s", nodes)
102
+
103
+	if add {
104
+		for _, rIP := range nodes {
105
+			if err := setupEncryption(lIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
106
+				log.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
107
+			}
108
+		}
109
+	} else {
110
+		if len(nodes) == 0 {
111
+			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
112
+				log.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
113
+			}
114
+		}
115
+	}
116
+
117
+	return nil
118
+}
119
+
120
+func setupEncryption(localIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
121
+	log.Infof("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
122
+	rIPs := remoteIP.String()
123
+
124
+	indices := make([]*spi, 0, len(keys))
125
+
126
+	err := programMangle(vni, true)
127
+	if err != nil {
128
+		log.Warn(err)
129
+	}
130
+
131
+	for i, k := range keys {
132
+		spis := &spi{buildSPI(localIP, remoteIP, k.tag), buildSPI(remoteIP, localIP, k.tag)}
133
+		dir := reverse
134
+		if i == 0 {
135
+			dir = bidir
136
+		}
137
+		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
138
+		if err != nil {
139
+			log.Warn(err)
140
+		}
141
+		indices = append(indices, spis)
142
+		if i != 0 {
143
+			continue
144
+		}
145
+		err = programSP(fSA, rSA, true)
146
+		if err != nil {
147
+			log.Warn(err)
148
+		}
149
+	}
150
+
151
+	em.Lock()
152
+	em.nodes[rIPs] = indices
153
+	em.Unlock()
154
+
155
+	return nil
156
+}
157
+
158
+func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
159
+	em.Lock()
160
+	indices, ok := em.nodes[remoteIP.String()]
161
+	em.Unlock()
162
+	if !ok {
163
+		return nil
164
+	}
165
+	for i, idxs := range indices {
166
+		dir := reverse
167
+		if i == 0 {
168
+			dir = bidir
169
+		}
170
+		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
171
+		if err != nil {
172
+			log.Warn(err)
173
+		}
174
+		if i != 0 {
175
+			continue
176
+		}
177
+		err = programSP(fSA, rSA, false)
178
+		if err != nil {
179
+			log.Warn(err)
180
+		}
181
+	}
182
+	return nil
183
+}
184
+
185
+func programMangle(vni uint32, add bool) (err error) {
186
+	var (
187
+		p      = strconv.FormatUint(uint64(vxlanPort), 10)
188
+		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
189
+		m      = strconv.FormatUint(uint64(mark), 10)
190
+		chain  = "OUTPUT"
191
+		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
192
+		a      = "-A"
193
+		action = "install"
194
+	)
195
+
196
+	if add == iptables.Exists(iptables.Mangle, chain, rule...) {
197
+		return
198
+	}
199
+
200
+	if !add {
201
+		a = "-D"
202
+		action = "remove"
203
+	}
204
+
205
+	if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
206
+		log.Warnf("could not %s mangle rule: %v", action, err)
207
+	}
208
+
209
+	return
210
+}
211
+
212
+func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
213
+	var (
214
+		crypt       *netlink.XfrmStateAlgo
215
+		action      = "Removing"
216
+		xfrmProgram = netlink.XfrmStateDel
217
+	)
218
+
219
+	if add {
220
+		action = "Adding"
221
+		xfrmProgram = netlink.XfrmStateAdd
222
+		crypt = &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: k.value}
223
+	}
224
+
225
+	if dir&reverse > 0 {
226
+		rSA = &netlink.XfrmState{
227
+			Src:   remoteIP,
228
+			Dst:   localIP,
229
+			Proto: netlink.XFRM_PROTO_ESP,
230
+			Spi:   spi.reverse,
231
+			Mode:  netlink.XFRM_MODE_TRANSPORT,
232
+		}
233
+		if add {
234
+			rSA.Crypt = crypt
235
+		}
236
+
237
+		exists, err := saExists(rSA)
238
+		if err != nil {
239
+			exists = !add
240
+		}
241
+
242
+		if add != exists {
243
+			log.Infof("%s: rSA{%s}", action, rSA)
244
+			if err := xfrmProgram(rSA); err != nil {
245
+				log.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
246
+			}
247
+		}
248
+	}
249
+
250
+	if dir&forward > 0 {
251
+		fSA = &netlink.XfrmState{
252
+			Src:   localIP,
253
+			Dst:   remoteIP,
254
+			Proto: netlink.XFRM_PROTO_ESP,
255
+			Spi:   spi.forward,
256
+			Mode:  netlink.XFRM_MODE_TRANSPORT,
257
+		}
258
+		if add {
259
+			fSA.Crypt = crypt
260
+		}
261
+
262
+		exists, err := saExists(fSA)
263
+		if err != nil {
264
+			exists = !add
265
+		}
266
+
267
+		if add != exists {
268
+			log.Infof("%s fSA{%s}", action, fSA)
269
+			if err := xfrmProgram(fSA); err != nil {
270
+				log.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
271
+			}
272
+		}
273
+	}
274
+
275
+	return
276
+}
277
+
278
+func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
279
+	action := "Removing"
280
+	xfrmProgram := netlink.XfrmPolicyDel
281
+	if add {
282
+		action = "Adding"
283
+		xfrmProgram = netlink.XfrmPolicyAdd
284
+	}
285
+
286
+	fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src))
287
+
288
+	fPol := &netlink.XfrmPolicy{
289
+		Src:     &net.IPNet{IP: fSA.Src, Mask: fullMask},
290
+		Dst:     &net.IPNet{IP: fSA.Dst, Mask: fullMask},
291
+		Dir:     netlink.XFRM_DIR_OUT,
292
+		Proto:   17,
293
+		DstPort: 4789,
294
+		Mark: &netlink.XfrmMark{
295
+			Value: mark,
296
+		},
297
+		Tmpls: []netlink.XfrmPolicyTmpl{
298
+			{
299
+				Src:   fSA.Src,
300
+				Dst:   fSA.Dst,
301
+				Proto: netlink.XFRM_PROTO_ESP,
302
+				Mode:  netlink.XFRM_MODE_TRANSPORT,
303
+				Spi:   fSA.Spi,
304
+			},
305
+		},
306
+	}
307
+
308
+	exists, err := spExists(fPol)
309
+	if err != nil {
310
+		exists = !add
311
+	}
312
+
313
+	if add != exists {
314
+		log.Infof("%s fSP{%s}", action, fPol)
315
+		if err := xfrmProgram(fPol); err != nil {
316
+			log.Warnf("%s fSP{%s}: %v", action, fPol, err)
317
+		}
318
+	}
319
+
320
+	return nil
321
+}
322
+
323
+func saExists(sa *netlink.XfrmState) (bool, error) {
324
+	_, err := netlink.XfrmStateGet(sa)
325
+	switch err {
326
+	case nil:
327
+		return true, nil
328
+	case syscall.ESRCH:
329
+		return false, nil
330
+	default:
331
+		err = fmt.Errorf("Error while checking for SA existence: %v", err)
332
+		log.Debug(err)
333
+		return false, err
334
+	}
335
+}
336
+
337
+func spExists(sp *netlink.XfrmPolicy) (bool, error) {
338
+	_, err := netlink.XfrmPolicyGet(sp)
339
+	switch err {
340
+	case nil:
341
+		return true, nil
342
+	case syscall.ENOENT:
343
+		return false, nil
344
+	default:
345
+		err = fmt.Errorf("Error while checking for SP existence: %v", err)
346
+		log.Debug(err)
347
+		return false, err
348
+	}
349
+}
350
+
351
// buildSPI derives an SPI value from the key tag st and the last four bytes
// of the src and dst addresses: byte i of src is XORed with byte 3-i of dst,
// shifted left by 8*i bits and folded into the tag. Swapping src and dst
// therefore generally yields a different value.
//
// Both addresses must be at least four bytes long.
func buildSPI(src, dst net.IP, st uint32) int {
	from := src[len(src)-4:]
	to := dst[len(dst)-4:]

	result := int(st)
	for i := range from {
		result ^= int(from[i]^to[3-i]) << uint32(8*i)
	}
	return result
}
360
+
361
+func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
362
+	d.secMap.Lock()
363
+	for node, indices := range d.secMap.nodes {
364
+		idxs, stop := f(node, indices)
365
+		if idxs != nil {
366
+			d.secMap.nodes[node] = idxs
367
+		}
368
+		if stop {
369
+			break
370
+		}
371
+	}
372
+	d.secMap.Unlock()
373
+	return nil
374
+}
375
+
376
+func (d *driver) setKeys(keys []*key) error {
377
+	if d.keys != nil {
378
+		return types.ForbiddenErrorf("initial keys are already present")
379
+	}
380
+	d.keys = keys
381
+	log.Infof("Initial encryption keys: %v", d.keys)
382
+	return nil
383
+}
384
+
385
+// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
386
+// The primary key is the key used in transmission and will go in first position in the list.
387
+func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
388
+	log.Infof("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
389
+
390
+	log.Infof("Current: %v", d.keys)
391
+
392
+	var (
393
+		newIdx = -1
394
+		priIdx = -1
395
+		delIdx = -1
396
+		lIP    = types.GetMinimalIP(net.ParseIP(d.bindAddress))
397
+	)
398
+
399
+	d.Lock()
400
+	// add new
401
+	if newKey != nil {
402
+		d.keys = append(d.keys, newKey)
403
+		newIdx += len(d.keys)
404
+	}
405
+	for i, k := range d.keys {
406
+		if primary != nil && k.tag == primary.tag {
407
+			priIdx = i
408
+		}
409
+		if pruneKey != nil && k.tag == pruneKey.tag {
410
+			delIdx = i
411
+		}
412
+	}
413
+	d.Unlock()
414
+
415
+	if (newKey != nil && newIdx == -1) ||
416
+		(primary != nil && priIdx == -1) ||
417
+		(pruneKey != nil && delIdx == -1) {
418
+		err := types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
419
+			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
420
+		log.Warn(err)
421
+		return err
422
+	}
423
+
424
+	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
425
+		rIP := types.GetMinimalIP(net.ParseIP(rIPs))
426
+		return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
427
+	})
428
+
429
+	d.Lock()
430
+	// swap primary
431
+	if priIdx != -1 {
432
+		swp := d.keys[0]
433
+		d.keys[0] = d.keys[priIdx]
434
+		d.keys[priIdx] = swp
435
+	}
436
+	// prune
437
+	if delIdx != -1 {
438
+		if delIdx == 0 {
439
+			delIdx = priIdx
440
+		}
441
+		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
442
+	}
443
+	d.Unlock()
444
+
445
+	log.Infof("Updated: %v", d.keys)
446
+
447
+	return nil
448
+}
449
+
450
+/********************************************************
451
+ * Steady state: rSA0, rSA1, fSA0, fSP0
452
+ * Rotation --> %rSA0, +rSA2, +fSA1, +fSP1/-fSP0, -fSA0,
453
+ * Half state:   rSA0, rSA1, rSA2, fSA1, fSP1
454
+ * Steady state: rSA1, rSA2, fSA1, fSP1
455
+ *********************************************************/
456
+
457
+// Spis and keys are sorted in such away the one in position 0 is the primary
458
+func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
459
+	log.Infof("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
460
+
461
+	spis := idxs
462
+	log.Infof("Current: %v", spis)
463
+
464
+	// add new
465
+	if newIdx != -1 {
466
+		spis = append(spis, &spi{
467
+			forward: buildSPI(lIP, rIP, curKeys[newIdx].tag),
468
+			reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag),
469
+		})
470
+	}
471
+
472
+	if delIdx != -1 {
473
+		// %rSA0
474
+		rSA0 := &netlink.XfrmState{
475
+			Src:    rIP,
476
+			Dst:    lIP,
477
+			Proto:  netlink.XFRM_PROTO_ESP,
478
+			Spi:    spis[delIdx].reverse,
479
+			Mode:   netlink.XFRM_MODE_TRANSPORT,
480
+			Crypt:  &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[delIdx].value},
481
+			Limits: netlink.XfrmStateLimits{TimeSoft: timeout},
482
+		}
483
+		log.Infof("Updating rSA0{%s}", rSA0)
484
+		if err := netlink.XfrmStateUpdate(rSA0); err != nil {
485
+			log.Warnf("Failed to update rSA0{%s}: %v", rSA0, err)
486
+		}
487
+	}
488
+
489
+	if newIdx > -1 {
490
+		// +RSA2
491
+		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
492
+	}
493
+
494
+	if priIdx > 0 {
495
+		// +fSA1
496
+		fSA1, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
497
+
498
+		// +fSP1, -fSP0
499
+		fullMask := net.CIDRMask(8*len(fSA1.Src), 8*len(fSA1.Src))
500
+		fSP1 := &netlink.XfrmPolicy{
501
+			Src:     &net.IPNet{IP: fSA1.Src, Mask: fullMask},
502
+			Dst:     &net.IPNet{IP: fSA1.Dst, Mask: fullMask},
503
+			Dir:     netlink.XFRM_DIR_OUT,
504
+			Proto:   17,
505
+			DstPort: 4789,
506
+			Mark: &netlink.XfrmMark{
507
+				Value: mark,
508
+			},
509
+			Tmpls: []netlink.XfrmPolicyTmpl{
510
+				{
511
+					Src:   fSA1.Src,
512
+					Dst:   fSA1.Dst,
513
+					Proto: netlink.XFRM_PROTO_ESP,
514
+					Mode:  netlink.XFRM_MODE_TRANSPORT,
515
+					Spi:   fSA1.Spi,
516
+				},
517
+			},
518
+		}
519
+		log.Infof("Updating fSP{%s}", fSP1)
520
+		if err := netlink.XfrmPolicyUpdate(fSP1); err != nil {
521
+			log.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
522
+		}
523
+
524
+		// -fSA0
525
+		fSA0 := &netlink.XfrmState{
526
+			Src:    lIP,
527
+			Dst:    rIP,
528
+			Proto:  netlink.XFRM_PROTO_ESP,
529
+			Spi:    spis[0].forward,
530
+			Mode:   netlink.XFRM_MODE_TRANSPORT,
531
+			Crypt:  &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[0].value},
532
+			Limits: netlink.XfrmStateLimits{TimeHard: timeout},
533
+		}
534
+		log.Infof("Removing fSA0{%s}", fSA0)
535
+		if err := netlink.XfrmStateUpdate(fSA0); err != nil {
536
+			log.Warnf("Failed to remove fSA0{%s}: %v", fSA0, err)
537
+		}
538
+	}
539
+
540
+	// swap
541
+	if priIdx > 0 {
542
+		swp := spis[0]
543
+		spis[0] = spis[priIdx]
544
+		spis[priIdx] = swp
545
+	}
546
+	// prune
547
+	if delIdx != -1 {
548
+		if delIdx == 0 {
549
+			delIdx = priIdx
550
+		}
551
+		spis = append(spis[:delIdx], spis[delIdx+1:]...)
552
+	}
553
+
554
+	log.Infof("Updated: %v", spis)
555
+
556
+	return spis
557
+}
558
+
559
+func parseEncryptionKey(value, tag string) (*key, error) {
560
+	var (
561
+		k   *key
562
+		err error
563
+	)
564
+	if value == "" {
565
+		return nil, nil
566
+	}
567
+	k = &key{}
568
+	if k.value, err = hex.DecodeString(value); err != nil {
569
+		return nil, types.BadRequestErrorf("failed to decode key (%s): %v", value, err)
570
+	}
571
+	t, err := strconv.ParseUint(tag, 10, 64)
572
+	if err != nil {
573
+		return nil, types.BadRequestErrorf("failed to decode tag (%s): %v", tag, err)
574
+	}
575
+	k.tag = uint32(t)
576
+	return k, nil
577
+}
... ...
@@ -27,6 +27,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
27 27
 		return fmt.Errorf("could not find endpoint with id %s", eid)
28 28
 	}
29 29
 
30
+	if n.secure && len(d.keys) == 0 {
31
+		return fmt.Errorf("cannot join secure network: encryption keys not present")
32
+	}
33
+
30 34
 	s := n.getSubnetforIP(ep.addr)
31 35
 	if s == nil {
32 36
 		return fmt.Errorf("could not find subnet for endpoint %s", eid)
... ...
@@ -106,6 +110,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
106 106
 	d.peerDbAdd(nid, eid, ep.addr.IP, ep.addr.Mask, ep.mac,
107 107
 		net.ParseIP(d.bindAddress), true)
108 108
 
109
+	if err := d.checkEncryption(nid, nil, n.vxlanID(s), true, true); err != nil {
110
+		log.Warn(err)
111
+	}
112
+
109 113
 	buf, err := proto.Marshal(&PeerRecord{
110 114
 		EndpointIP:       ep.addr.String(),
111 115
 		EndpointMAC:      ep.mac.String(),
... ...
@@ -197,5 +205,9 @@ func (d *driver) Leave(nid, eid string) error {
197 197
 
198 198
 	n.leaveSandbox()
199 199
 
200
+	if err := d.checkEncryption(nid, nil, 0, true, false); err != nil {
201
+		log.Warn(err)
202
+	}
203
+
200 204
 	return nil
201 205
 }
... ...
@@ -61,6 +61,7 @@ type network struct {
61 61
 	initEpoch int
62 62
 	initErr   error
63 63
 	subnets   []*subnet
64
+	secure    bool
64 65
 	sync.Mutex
65 66
 }
66 67
 
... ...
@@ -109,6 +110,9 @@ func (d *driver) CreateNetwork(id string, option map[string]interface{}, nInfo d
109 109
 				vnis = append(vnis, uint32(vni))
110 110
 			}
111 111
 		}
112
+		if _, ok := optMap["secure"]; ok {
113
+			n.secure = true
114
+		}
112 115
 	}
113 116
 
114 117
 	// If we are getting vnis from libnetwork, either we get for
... ...
@@ -162,7 +166,18 @@ func (d *driver) DeleteNetwork(nid string) error {
162 162
 
163 163
 	d.deleteNetwork(nid)
164 164
 
165
-	return n.releaseVxlanID()
165
+	vnis, err := n.releaseVxlanID()
166
+	if err != nil {
167
+		return err
168
+	}
169
+
170
+	if n.secure {
171
+		for _, vni := range vnis {
172
+			programMangle(vni, false)
173
+		}
174
+	}
175
+
176
+	return nil
166 177
 }
167 178
 
168 179
 func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error {
... ...
@@ -618,6 +633,8 @@ func (n *network) KeyPrefix() []string {
618 618
 }
619 619
 
620 620
 func (n *network) Value() []byte {
621
+	m := map[string]interface{}{}
622
+
621 623
 	netJSON := []*subnetJSON{}
622 624
 
623 625
 	for _, s := range n.subnets {
... ...
@@ -630,10 +647,17 @@ func (n *network) Value() []byte {
630 630
 	}
631 631
 
632 632
 	b, err := json.Marshal(netJSON)
633
+	if err != nil {
634
+		return []byte{}
635
+	}
633 636
 
637
+	m["secure"] = n.secure
638
+	m["subnets"] = netJSON
639
+	b, err = json.Marshal(m)
634 640
 	if err != nil {
635 641
 		return []byte{}
636 642
 	}
643
+
637 644
 	return b
638 645
 }
639 646
 
... ...
@@ -655,18 +679,38 @@ func (n *network) Skip() bool {
655 655
 }
656 656
 
657 657
 func (n *network) SetValue(value []byte) error {
658
-	var newNet bool
659
-	netJSON := []*subnetJSON{}
660
-
661
-	err := json.Unmarshal(value, &netJSON)
662
-	if err != nil {
663
-		return err
658
+	var (
659
+		m       map[string]interface{}
660
+		newNet  bool
661
+		isMap   = true
662
+		netJSON = []*subnetJSON{}
663
+	)
664
+
665
+	if err := json.Unmarshal(value, &m); err != nil {
666
+		err := json.Unmarshal(value, &netJSON)
667
+		if err != nil {
668
+			return err
669
+		}
670
+		isMap = false
664 671
 	}
665 672
 
666 673
 	if len(n.subnets) == 0 {
667 674
 		newNet = true
668 675
 	}
669 676
 
677
+	if isMap {
678
+		if val, ok := m["secure"]; ok {
679
+			n.secure = val.(bool)
680
+		}
681
+		bytes, err := json.Marshal(m["subnets"])
682
+		if err != nil {
683
+			return err
684
+		}
685
+		if err := json.Unmarshal(bytes, &netJSON); err != nil {
686
+			return err
687
+		}
688
+	}
689
+
670 690
 	for _, sj := range netJSON {
671 691
 		subnetIPstr := sj.SubnetIP
672 692
 		gwIPstr := sj.GwIP
... ...
@@ -705,9 +749,9 @@ func (n *network) writeToStore() error {
705 705
 	return n.driver.store.PutObjectAtomic(n)
706 706
 }
707 707
 
708
-func (n *network) releaseVxlanID() error {
708
+func (n *network) releaseVxlanID() ([]uint32, error) {
709 709
 	if len(n.subnets) == 0 {
710
-		return nil
710
+		return nil, nil
711 711
 	}
712 712
 
713 713
 	if n.driver.store != nil {
... ...
@@ -715,22 +759,24 @@ func (n *network) releaseVxlanID() error {
715 715
 			if err == datastore.ErrKeyModified || err == datastore.ErrKeyNotFound {
716 716
 				// In both the above cases we can safely assume that the key has been removed by some other
717 717
 				// instance and so simply get out of here
718
-				return nil
718
+				return nil, nil
719 719
 			}
720 720
 
721
-			return fmt.Errorf("failed to delete network to vxlan id map: %v", err)
721
+			return nil, fmt.Errorf("failed to delete network to vxlan id map: %v", err)
722 722
 		}
723 723
 	}
724
-
724
+	var vnis []uint32
725 725
 	for _, s := range n.subnets {
726 726
 		if n.driver.vxlanIdm != nil {
727
-			n.driver.vxlanIdm.Release(uint64(n.vxlanID(s)))
727
+			vni := n.vxlanID(s)
728
+			vnis = append(vnis, vni)
729
+			n.driver.vxlanIdm.Release(uint64(vni))
728 730
 		}
729 731
 
730 732
 		n.setVxlanID(s, 0)
731 733
 	}
732 734
 
733
-	return nil
735
+	return vnis, nil
734 736
 }
735 737
 
736 738
 func (n *network) obtainVxlanID(s *subnet) error {
... ...
@@ -37,12 +37,14 @@ type driver struct {
37 37
 	neighIP      string
38 38
 	config       map[string]interface{}
39 39
 	peerDb       peerNetworkMap
40
+	secMap       *encrMap
40 41
 	serfInstance *serf.Serf
41 42
 	networks     networkTable
42 43
 	store        datastore.DataStore
43 44
 	vxlanIdm     *idm.Idm
44 45
 	once         sync.Once
45 46
 	joinOnce     sync.Once
47
+	keys         []*key
46 48
 	sync.Mutex
47 49
 }
48 50
 
... ...
@@ -51,12 +53,12 @@ func Init(dc driverapi.DriverCallback, config map[string]interface{}) error {
51 51
 	c := driverapi.Capability{
52 52
 		DataScope: datastore.GlobalScope,
53 53
 	}
54
-
55 54
 	d := &driver{
56 55
 		networks: networkTable{},
57 56
 		peerDb: peerNetworkMap{
58 57
 			mp: map[string]*peerMap{},
59 58
 		},
59
+		secMap: &encrMap{nodes: map[string][]*spi{}},
60 60
 		config: config,
61 61
 	}
62 62
 
... ...
@@ -209,6 +211,7 @@ func (d *driver) pushLocalEndpointEvent(action, nid, eid string) {
209 209
 
210 210
 // DiscoverNew is a notification for a new discovery event, such as a new node joining a cluster
211 211
 func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{}) error {
212
+	var err error
212 213
 	switch dType {
213 214
 	case discoverapi.NodeDiscovery:
214 215
 		nodeData, ok := data.(discoverapi.NodeDiscoveryData)
... ...
@@ -217,7 +220,6 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{})
217 217
 		}
218 218
 		d.nodeJoin(nodeData.Address, nodeData.Self)
219 219
 	case discoverapi.DatastoreConfig:
220
-		var err error
221 220
 		if d.store != nil {
222 221
 			return types.ForbiddenErrorf("cannot accept datastore configuration: Overlay driver has a datastore configured already")
223 222
 		}
... ...
@@ -229,6 +231,39 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{})
229 229
 		if err != nil {
230 230
 			return types.InternalErrorf("failed to initialize data store: %v", err)
231 231
 		}
232
+	case discoverapi.EncryptionKeysConfig:
233
+		encrData, ok := data.(discoverapi.DriverEncryptionConfig)
234
+		if !ok {
235
+			return fmt.Errorf("invalid encryption key notification data")
236
+		}
237
+		keys := make([]*key, 0, len(encrData.Keys))
238
+		for i := 0; i < len(encrData.Keys); i++ {
239
+			k, err := parseEncryptionKey(encrData.Keys[i], encrData.Tags[i])
240
+			if err != nil {
241
+				return err
242
+			}
243
+			keys = append(keys, k)
244
+		}
245
+		d.setKeys(keys)
246
+	case discoverapi.EncryptionKeysUpdate:
247
+		var newKey, delKey, priKey *key
248
+		encrData, ok := data.(discoverapi.DriverEncryptionUpdate)
249
+		if !ok {
250
+			return fmt.Errorf("invalid encryption key notification data")
251
+		}
252
+		newKey, err = parseEncryptionKey(encrData.Key, encrData.Tag)
253
+		if err != nil {
254
+			return err
255
+		}
256
+		priKey, err = parseEncryptionKey(encrData.Primary, encrData.PrimaryTag)
257
+		if err != nil {
258
+			return err
259
+		}
260
+		delKey, err = parseEncryptionKey(encrData.Prune, encrData.PruneTag)
261
+		if err != nil {
262
+			return err
263
+		}
264
+		d.updateKeys(newKey, priKey, delKey)
232 265
 	default:
233 266
 	}
234 267
 	return nil
... ...
@@ -5,6 +5,8 @@ import (
5 5
 	"net"
6 6
 	"sync"
7 7
 	"syscall"
8
+
9
+	log "github.com/Sirupsen/logrus"
8 10
 )
9 11
 
10 12
 const ovPeerTable = "overlay_peer_table"
... ...
@@ -88,7 +90,7 @@ func (d *driver) peerDbNetworkWalk(nid string, f func(*peerKey, *peerEntry) bool
88 88
 	for pKeyStr, pEntry := range pMap.mp {
89 89
 		var pKey peerKey
90 90
 		if _, err := fmt.Sscan(pKeyStr, &pKey); err != nil {
91
-			fmt.Printf("peer key scan failed: %v", err)
91
+			log.Warnf("Peer key scan on network %s failed: %v", nid, err)
92 92
 		}
93 93
 
94 94
 		if f(&pKey, &pEntry) {
... ...
@@ -273,6 +275,10 @@ func (d *driver) peerAdd(nid, eid string, peerIP net.IP, peerIPMask net.IPMask,
273 273
 		return fmt.Errorf("subnet sandbox join failed for %q: %v", s.subnetIP.String(), err)
274 274
 	}
275 275
 
276
+	if err := d.checkEncryption(nid, vtep, n.vxlanID(s), false, true); err != nil {
277
+		log.Warn(err)
278
+	}
279
+
276 280
 	// Add neighbor entry for the peer IP
277 281
 	if err := sbox.AddNeighbor(peerIP, peerMac, sbox.NeighborOptions().LinkName(s.vxlanName)); err != nil {
278 282
 		return fmt.Errorf("could not add neigbor entry into the sandbox: %v", err)
... ...
@@ -318,6 +324,10 @@ func (d *driver) peerDelete(nid, eid string, peerIP net.IP, peerIPMask net.IPMas
318 318
 		return fmt.Errorf("could not delete neigbor entry into the sandbox: %v", err)
319 319
 	}
320 320
 
321
+	if err := d.checkEncryption(nid, vtep, 0, false, false); err != nil {
322
+		log.Warn(err)
323
+	}
324
+
321 325
 	return nil
322 326
 }
323 327