Browse code

Refactor global portallocator and portmapper state

Continuation of: #11660, working on issue #11626.

Wrapped portmapper global state into a struct. Now portallocator and
portmapper have no global state (except configuration, and a default
instance).

Unfortunately, removing the global default instances will break
```api/server/server.go:1539```, and ```daemon/daemon.go:832```, which
both call the global portallocator directly. Fixing that would be a much
bigger change, so for now, have postponed that.

Signed-off-by: Paul Bellamy <paul.a.bellamy@gmail.com>

Paul Bellamy authored on 2015/03/24 19:29:30
Showing 3 changed files
... ...
@@ -50,8 +50,12 @@ var (
50 50
 )
51 51
 
52 52
 var (
53
-	defaultIP            = net.ParseIP("0.0.0.0")
54
-	defaultPortAllocator = New()
53
+	defaultIP = net.ParseIP("0.0.0.0")
54
+
55
+	DefaultPortAllocator = New()
56
+	RequestPort          = DefaultPortAllocator.RequestPort
57
+	ReleasePort          = DefaultPortAllocator.ReleasePort
58
+	ReleaseAll           = DefaultPortAllocator.ReleaseAll
55 59
 )
56 60
 
57 61
 type PortAllocator struct {
... ...
@@ -119,6 +123,9 @@ func (e ErrPortAlreadyAllocated) Error() string {
119 119
 	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port)
120 120
 }
121 121
 
122
+// RequestPort requests new port from global ports pool for specified ip and proto.
123
+// If port is 0 it returns first free port. Otherwise it cheks port availability
124
+// in pool and return that port or error if port is already busy.
122 125
 func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
123 126
 	p.mutex.Lock()
124 127
 	defer p.mutex.Unlock()
... ...
@@ -152,13 +159,6 @@ func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, err
152 152
 	return port, nil
153 153
 }
154 154
 
155
-// RequestPort requests new port from global ports pool for specified ip and proto.
156
-// If port is 0 it returns first free port. Otherwise it cheks port availability
157
-// in pool and return that port or error if port is already busy.
158
-func RequestPort(ip net.IP, proto string, port int) (int, error) {
159
-	return defaultPortAllocator.RequestPort(ip, proto, port)
160
-}
161
-
162 155
 // ReleasePort releases port from global ports pool for specified ip and proto.
163 156
 func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
164 157
 	p.mutex.Lock()
... ...
@@ -175,10 +175,6 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
175 175
 	return nil
176 176
 }
177 177
 
178
-func ReleasePort(ip net.IP, proto string, port int) error {
179
-	return defaultPortAllocator.ReleasePort(ip, proto, port)
180
-}
181
-
182 178
 // ReleaseAll releases all ports for all ips.
183 179
 func (p *PortAllocator) ReleaseAll() error {
184 180
 	p.mutex.Lock()
... ...
@@ -187,10 +183,6 @@ func (p *PortAllocator) ReleaseAll() error {
187 187
 	return nil
188 188
 }
189 189
 
190
-func ReleaseAll() error {
191
-	return defaultPortAllocator.ReleaseAll()
192
-}
193
-
194 190
 func (pm *portMap) findPort() (int, error) {
195 191
 	port := pm.last
196 192
 	for i := 0; i <= endPortRange-beginPortRange; i++ {
... ...
@@ -19,13 +19,12 @@ type mapping struct {
19 19
 }
20 20
 
21 21
 var (
22
-	chain *iptables.Chain
23
-	lock  sync.Mutex
24
-
25
-	// udp:ip:port
26
-	currentMappings = make(map[string]*mapping)
27
-
28 22
 	NewProxy = NewProxyCommand
23
+
24
+	DefaultPortMapper = NewWithPortAllocator(portallocator.DefaultPortAllocator)
25
+	SetIptablesChain  = DefaultPortMapper.SetIptablesChain
26
+	Map               = DefaultPortMapper.Map
27
+	Unmap             = DefaultPortMapper.Unmap
29 28
 )
30 29
 
31 30
 var (
... ...
@@ -34,13 +33,34 @@ var (
34 34
 	ErrPortNotMapped             = errors.New("port is not mapped")
35 35
 )
36 36
 
37
-func SetIptablesChain(c *iptables.Chain) {
38
-	chain = c
37
+type PortMapper struct {
38
+	chain *iptables.Chain
39
+
40
+	// udp:ip:port
41
+	currentMappings map[string]*mapping
42
+	lock            sync.Mutex
43
+
44
+	allocator *portallocator.PortAllocator
45
+}
46
+
47
+func New() *PortMapper {
48
+	return NewWithPortAllocator(portallocator.New())
49
+}
50
+
51
+func NewWithPortAllocator(allocator *portallocator.PortAllocator) *PortMapper {
52
+	return &PortMapper{
53
+		currentMappings: make(map[string]*mapping),
54
+		allocator:       allocator,
55
+	}
56
+}
57
+
58
+func (pm *PortMapper) SetIptablesChain(c *iptables.Chain) {
59
+	pm.chain = c
39 60
 }
40 61
 
41
-func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) {
42
-	lock.Lock()
43
-	defer lock.Unlock()
62
+func (pm *PortMapper) Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) {
63
+	pm.lock.Lock()
64
+	defer pm.lock.Unlock()
44 65
 
45 66
 	var (
46 67
 		m                 *mapping
... ...
@@ -52,7 +72,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
52 52
 	switch container.(type) {
53 53
 	case *net.TCPAddr:
54 54
 		proto = "tcp"
55
-		if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil {
55
+		if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil {
56 56
 			return nil, err
57 57
 		}
58 58
 
... ...
@@ -65,7 +85,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
65 65
 		proxy = NewProxy(proto, hostIP, allocatedHostPort, container.(*net.TCPAddr).IP, container.(*net.TCPAddr).Port)
66 66
 	case *net.UDPAddr:
67 67
 		proto = "udp"
68
-		if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil {
68
+		if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil {
69 69
 			return nil, err
70 70
 		}
71 71
 
... ...
@@ -83,25 +103,25 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
83 83
 	// release the allocated port on any further error during return.
84 84
 	defer func() {
85 85
 		if err != nil {
86
-			portallocator.ReleasePort(hostIP, proto, allocatedHostPort)
86
+			pm.allocator.ReleasePort(hostIP, proto, allocatedHostPort)
87 87
 		}
88 88
 	}()
89 89
 
90 90
 	key := getKey(m.host)
91
-	if _, exists := currentMappings[key]; exists {
91
+	if _, exists := pm.currentMappings[key]; exists {
92 92
 		return nil, ErrPortMappedForIP
93 93
 	}
94 94
 
95 95
 	containerIP, containerPort := getIPAndPort(m.container)
96
-	if err := forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
96
+	if err := pm.forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
97 97
 		return nil, err
98 98
 	}
99 99
 
100 100
 	cleanup := func() error {
101 101
 		// need to undo the iptables rules before we return
102 102
 		proxy.Stop()
103
-		forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
104
-		if err := portallocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
103
+		pm.forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
104
+		if err := pm.allocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
105 105
 			return err
106 106
 		}
107 107
 
... ...
@@ -115,35 +135,35 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
115 115
 		return nil, err
116 116
 	}
117 117
 	m.userlandProxy = proxy
118
-	currentMappings[key] = m
118
+	pm.currentMappings[key] = m
119 119
 	return m.host, nil
120 120
 }
121 121
 
122
-func Unmap(host net.Addr) error {
123
-	lock.Lock()
124
-	defer lock.Unlock()
122
+func (pm *PortMapper) Unmap(host net.Addr) error {
123
+	pm.lock.Lock()
124
+	defer pm.lock.Unlock()
125 125
 
126 126
 	key := getKey(host)
127
-	data, exists := currentMappings[key]
127
+	data, exists := pm.currentMappings[key]
128 128
 	if !exists {
129 129
 		return ErrPortNotMapped
130 130
 	}
131 131
 
132 132
 	data.userlandProxy.Stop()
133 133
 
134
-	delete(currentMappings, key)
134
+	delete(pm.currentMappings, key)
135 135
 
136 136
 	containerIP, containerPort := getIPAndPort(data.container)
137 137
 	hostIP, hostPort := getIPAndPort(data.host)
138
-	if err := forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil {
138
+	if err := pm.forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil {
139 139
 		log.Errorf("Error on iptables delete: %s", err)
140 140
 	}
141 141
 
142 142
 	switch a := host.(type) {
143 143
 	case *net.TCPAddr:
144
-		return portallocator.ReleasePort(a.IP, "tcp", a.Port)
144
+		return pm.allocator.ReleasePort(a.IP, "tcp", a.Port)
145 145
 	case *net.UDPAddr:
146
-		return portallocator.ReleasePort(a.IP, "udp", a.Port)
146
+		return pm.allocator.ReleasePort(a.IP, "udp", a.Port)
147 147
 	}
148 148
 	return nil
149 149
 }
... ...
@@ -168,9 +188,9 @@ func getIPAndPort(a net.Addr) (net.IP, int) {
168 168
 	return nil, 0
169 169
 }
170 170
 
171
-func forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error {
172
-	if chain == nil {
171
+func (pm *PortMapper) forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error {
172
+	if pm.chain == nil {
173 173
 		return nil
174 174
 	}
175
-	return chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort)
175
+	return pm.chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort)
176 176
 }
... ...
@@ -13,30 +13,26 @@ func init() {
13 13
 	NewProxy = NewMockProxyCommand
14 14
 }
15 15
 
16
-func reset() {
17
-	chain = nil
18
-	currentMappings = make(map[string]*mapping)
19
-}
20
-
21 16
 func TestSetIptablesChain(t *testing.T) {
22
-	defer reset()
17
+	pm := New()
23 18
 
24 19
 	c := &iptables.Chain{
25 20
 		Name:   "TEST",
26 21
 		Bridge: "192.168.1.1",
27 22
 	}
28 23
 
29
-	if chain != nil {
24
+	if pm.chain != nil {
30 25
 		t.Fatal("chain should be nil at init")
31 26
 	}
32 27
 
33
-	SetIptablesChain(c)
34
-	if chain == nil {
28
+	pm.SetIptablesChain(c)
29
+	if pm.chain == nil {
35 30
 		t.Fatal("chain should not be nil after set")
36 31
 	}
37 32
 }
38 33
 
39 34
 func TestMapPorts(t *testing.T) {
35
+	pm := New()
40 36
 	dstIp1 := net.ParseIP("192.168.0.1")
41 37
 	dstIp2 := net.ParseIP("192.168.0.2")
42 38
 	dstAddr1 := &net.TCPAddr{IP: dstIp1, Port: 80}
... ...
@@ -49,34 +45,34 @@ func TestMapPorts(t *testing.T) {
49 49
 		return (addr1.Network() == addr2.Network()) && (addr1.String() == addr2.String())
50 50
 	}
51 51
 
52
-	if host, err := Map(srcAddr1, dstIp1, 80); err != nil {
52
+	if host, err := pm.Map(srcAddr1, dstIp1, 80); err != nil {
53 53
 		t.Fatalf("Failed to allocate port: %s", err)
54 54
 	} else if !addrEqual(dstAddr1, host) {
55 55
 		t.Fatalf("Incorrect mapping result: expected %s:%s, got %s:%s",
56 56
 			dstAddr1.String(), dstAddr1.Network(), host.String(), host.Network())
57 57
 	}
58 58
 
59
-	if _, err := Map(srcAddr1, dstIp1, 80); err == nil {
59
+	if _, err := pm.Map(srcAddr1, dstIp1, 80); err == nil {
60 60
 		t.Fatalf("Port is in use - mapping should have failed")
61 61
 	}
62 62
 
63
-	if _, err := Map(srcAddr2, dstIp1, 80); err == nil {
63
+	if _, err := pm.Map(srcAddr2, dstIp1, 80); err == nil {
64 64
 		t.Fatalf("Port is in use - mapping should have failed")
65 65
 	}
66 66
 
67
-	if _, err := Map(srcAddr2, dstIp2, 80); err != nil {
67
+	if _, err := pm.Map(srcAddr2, dstIp2, 80); err != nil {
68 68
 		t.Fatalf("Failed to allocate port: %s", err)
69 69
 	}
70 70
 
71
-	if Unmap(dstAddr1) != nil {
71
+	if pm.Unmap(dstAddr1) != nil {
72 72
 		t.Fatalf("Failed to release port")
73 73
 	}
74 74
 
75
-	if Unmap(dstAddr2) != nil {
75
+	if pm.Unmap(dstAddr2) != nil {
76 76
 		t.Fatalf("Failed to release port")
77 77
 	}
78 78
 
79
-	if Unmap(dstAddr2) == nil {
79
+	if pm.Unmap(dstAddr2) == nil {
80 80
 		t.Fatalf("Port already released, but no error reported")
81 81
 	}
82 82
 }
... ...
@@ -115,6 +111,7 @@ func TestGetUDPIPAndPort(t *testing.T) {
115 115
 }
116 116
 
117 117
 func TestMapAllPortsSingleInterface(t *testing.T) {
118
+	pm := New()
118 119
 	dstIp1 := net.ParseIP("0.0.0.0")
119 120
 	srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")}
120 121
 
... ...
@@ -124,26 +121,26 @@ func TestMapAllPortsSingleInterface(t *testing.T) {
124 124
 
125 125
 	defer func() {
126 126
 		for _, val := range hosts {
127
-			Unmap(val)
127
+			pm.Unmap(val)
128 128
 		}
129 129
 	}()
130 130
 
131 131
 	for i := 0; i < 10; i++ {
132 132
 		start, end := portallocator.PortRange()
133 133
 		for i := start; i < end; i++ {
134
-			if host, err = Map(srcAddr1, dstIp1, 0); err != nil {
134
+			if host, err = pm.Map(srcAddr1, dstIp1, 0); err != nil {
135 135
 				t.Fatal(err)
136 136
 			}
137 137
 
138 138
 			hosts = append(hosts, host)
139 139
 		}
140 140
 
141
-		if _, err := Map(srcAddr1, dstIp1, start); err == nil {
141
+		if _, err := pm.Map(srcAddr1, dstIp1, start); err == nil {
142 142
 			t.Fatalf("Port %d should be bound but is not", start)
143 143
 		}
144 144
 
145 145
 		for _, val := range hosts {
146
-			if err := Unmap(val); err != nil {
146
+			if err := pm.Unmap(val); err != nil {
147 147
 				t.Fatal(err)
148 148
 			}
149 149
 		}