Browse code

Defer PTR queries to external servers based on A/AAAA response

Signed-off-by: Santhosh Manohar <santhosh@docker.com>

Santhosh Manohar authored on 2016/12/07 07:56:24
Showing 4 changed files
... ...
@@ -380,7 +380,7 @@ func TestSRVServiceQuery(t *testing.T) {
380 380
 	sr := svcInfo{
381 381
 		svcMap:     make(map[string][]net.IP),
382 382
 		svcIPv6Map: make(map[string][]net.IP),
383
-		ipMap:      make(map[string]string),
383
+		ipMap:      make(map[string]*ipInfo),
384 384
 		service:    make(map[string][]servicePorts),
385 385
 	}
386 386
 	// backing container for the service
... ...
@@ -80,10 +80,18 @@ type NetworkInfo interface {
80 80
 // When the function returns true, the walk will stop.
81 81
 type EndpointWalker func(ep Endpoint) bool
82 82
 
83
+// ipInfo is the reverse mapping from IP to service name to serve the PTR query.
84
+// extResolver is set if an externl server resolves a service name to this IP.
85
+// Its an indication to defer PTR queries also to that external server.
86
+type ipInfo struct {
87
+	name        string
88
+	extResolver bool
89
+}
90
+
83 91
 type svcInfo struct {
84 92
 	svcMap     map[string][]net.IP
85 93
 	svcIPv6Map map[string][]net.IP
86
-	ipMap      map[string]string
94
+	ipMap      map[string]*ipInfo
87 95
 	service    map[string][]servicePorts
88 96
 }
89 97
 
... ...
@@ -1070,10 +1078,12 @@ func (n *network) updateSvcRecord(ep *endpoint, localEps []*endpoint, isAdd bool
1070 1070
 	}
1071 1071
 }
1072 1072
 
1073
-func addIPToName(ipMap map[string]string, name string, ip net.IP) {
1073
+func addIPToName(ipMap map[string]*ipInfo, name string, ip net.IP) {
1074 1074
 	reverseIP := netutils.ReverseIP(ip.String())
1075 1075
 	if _, ok := ipMap[reverseIP]; !ok {
1076
-		ipMap[reverseIP] = name
1076
+		ipMap[reverseIP] = &ipInfo{
1077
+			name: name,
1078
+		}
1077 1079
 	}
1078 1080
 }
1079 1081
 
... ...
@@ -1117,7 +1127,7 @@ func (n *network) addSvcRecords(name string, epIP net.IP, epIPv6 net.IP, ipMapUp
1117 1117
 		sr = svcInfo{
1118 1118
 			svcMap:     make(map[string][]net.IP),
1119 1119
 			svcIPv6Map: make(map[string][]net.IP),
1120
-			ipMap:      make(map[string]string),
1120
+			ipMap:      make(map[string]*ipInfo),
1121 1121
 		}
1122 1122
 		c.svcRecords[n.ID()] = sr
1123 1123
 	}
... ...
@@ -1612,8 +1622,8 @@ func (n *network) ResolveName(req string, ipType int) ([]net.IP, bool) {
1612 1612
 
1613 1613
 	c := n.getController()
1614 1614
 	c.Lock()
1615
+	defer c.Unlock()
1615 1616
 	sr, ok := c.svcRecords[n.ID()]
1616
-	c.Unlock()
1617 1617
 
1618 1618
 	if !ok {
1619 1619
 		return nil, false
... ...
@@ -1621,7 +1631,6 @@ func (n *network) ResolveName(req string, ipType int) ([]net.IP, bool) {
1621 1621
 
1622 1622
 	req = strings.TrimSuffix(req, ".")
1623 1623
 	var ip []net.IP
1624
-	n.Lock()
1625 1624
 	ip, ok = sr.svcMap[req]
1626 1625
 
1627 1626
 	if ipType == types.IPv6 {
... ...
@@ -1634,7 +1643,6 @@ func (n *network) ResolveName(req string, ipType int) ([]net.IP, bool) {
1634 1634
 		}
1635 1635
 		ip = sr.svcIPv6Map[req]
1636 1636
 	}
1637
-	n.Unlock()
1638 1637
 
1639 1638
 	if ip != nil {
1640 1639
 		return ip, false
... ...
@@ -1643,13 +1651,28 @@ func (n *network) ResolveName(req string, ipType int) ([]net.IP, bool) {
1643 1643
 	return nil, ipv6Miss
1644 1644
 }
1645 1645
 
1646
-func (n *network) ResolveIP(ip string) string {
1647
-	var svc string
1646
+func (n *network) HandleQueryResp(name string, ip net.IP) {
1647
+	c := n.getController()
1648
+	c.Lock()
1649
+	defer c.Unlock()
1650
+	sr, ok := c.svcRecords[n.ID()]
1648 1651
 
1652
+	if !ok {
1653
+		return
1654
+	}
1655
+
1656
+	ipStr := netutils.ReverseIP(ip.String())
1657
+
1658
+	if ipInfo, ok := sr.ipMap[ipStr]; ok {
1659
+		ipInfo.extResolver = true
1660
+	}
1661
+}
1662
+
1663
+func (n *network) ResolveIP(ip string) string {
1649 1664
 	c := n.getController()
1650 1665
 	c.Lock()
1666
+	defer c.Unlock()
1651 1667
 	sr, ok := c.svcRecords[n.ID()]
1652
-	c.Unlock()
1653 1668
 
1654 1669
 	if !ok {
1655 1670
 		return ""
... ...
@@ -1657,15 +1680,13 @@ func (n *network) ResolveIP(ip string) string {
1657 1657
 
1658 1658
 	nwName := n.Name()
1659 1659
 
1660
-	n.Lock()
1661
-	defer n.Unlock()
1662
-	svc, ok = sr.ipMap[ip]
1660
+	ipInfo, ok := sr.ipMap[ip]
1663 1661
 
1664
-	if ok {
1665
-		return svc + "." + nwName
1662
+	if !ok || ipInfo.extResolver {
1663
+		return ""
1666 1664
 	}
1667 1665
 
1668
-	return svc
1666
+	return ipInfo.name + "." + nwName
1669 1667
 }
1670 1668
 
1671 1669
 func (n *network) ResolveService(name string) ([]*net.SRV, []net.IP) {
... ...
@@ -1689,8 +1710,8 @@ func (n *network) ResolveService(name string) ([]*net.SRV, []net.IP) {
1689 1689
 	svcName := strings.Join(parts[2:], ".")
1690 1690
 
1691 1691
 	c.Lock()
1692
+	defer c.Unlock()
1692 1693
 	sr, ok := c.svcRecords[n.ID()]
1693
-	c.Unlock()
1694 1694
 
1695 1695
 	if !ok {
1696 1696
 		return nil, nil
... ...
@@ -54,6 +54,9 @@ type DNSBackend interface {
54 54
 	ExecFunc(f func()) error
55 55
 	//NdotsSet queries the backends ndots dns option settings
56 56
 	NdotsSet() bool
57
+	// HandleQueryResp passes the name & IP from a response to the backend. backend
58
+	// can use it to maintain any required state about the resolution
59
+	HandleQueryResp(name string, ip net.IP)
57 60
 }
58 61
 
59 62
 const (
... ...
@@ -462,9 +465,20 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
462 462
 				logrus.Debugf("Read from DNS server failed, %s", err)
463 463
 				continue
464 464
 			}
465
-
466 465
 			r.forwardQueryEnd()
467
-
466
+			if resp != nil {
467
+				for _, rr := range resp.Answer {
468
+					h := rr.Header()
469
+					switch h.Rrtype {
470
+					case dns.TypeA:
471
+						ip := rr.(*dns.A).A
472
+						r.backend.HandleQueryResp(h.Name, ip)
473
+					case dns.TypeAAAA:
474
+						ip := rr.(*dns.AAAA).AAAA
475
+						r.backend.HandleQueryResp(h.Name, ip)
476
+					}
477
+				}
478
+			}
468 479
 			resp.Compress = true
469 480
 			break
470 481
 		}
... ...
@@ -411,6 +411,13 @@ func (sb *sandbox) updateGateway(ep *endpoint) error {
411 411
 	return nil
412 412
 }
413 413
 
414
+func (sb *sandbox) HandleQueryResp(name string, ip net.IP) {
415
+	for _, ep := range sb.getConnectedEndpoints() {
416
+		n := ep.getNetwork()
417
+		n.HandleQueryResp(name, ip)
418
+	}
419
+}
420
+
414 421
 func (sb *sandbox) ResolveIP(ip string) string {
415 422
 	var svc string
416 423
 	logrus.Debugf("IP To resolve %v", ip)