Browse code

Add tests of tcp port allocator

Solomon Hykes authored on 2013/04/06 05:03:24
Showing 2 changed files
... ...
@@ -176,6 +176,7 @@ func (alloc *PortAllocator) runFountain() {
176 176
 
177 177
 // FIXME: Release can no longer fail, change its prototype to reflect that.
178 178
 func (alloc *PortAllocator) Release(port int) error {
179
+	Debugf("Releasing %d", port)
179 180
 	alloc.lock.Lock()
180 181
 	delete(alloc.inUse, port)
181 182
 	alloc.lock.Unlock()
... ...
@@ -183,6 +184,7 @@ func (alloc *PortAllocator) Release(port int) error {
183 183
 }
184 184
 
185 185
 func (alloc *PortAllocator) Acquire(port int) (int, error) {
186
+	Debugf("Acquiring %d", port)
186 187
 	if port == 0 {
187 188
 		// Allocate a port from the fountain
188 189
 		for port := range alloc.fountain {
... ...
@@ -18,6 +18,42 @@ func TestIptables(t *testing.T) {
18 18
 	}
19 19
 }
20 20
 
21
+func TestPortAllocation(t *testing.T) {
22
+	allocator, err := newPortAllocator()
23
+	if err != nil {
24
+		t.Fatal(err)
25
+	}
26
+	if port, err := allocator.Acquire(80); err != nil {
27
+		t.Fatal(err)
28
+	} else if port != 80 {
29
+		t.Fatalf("Acquire(80) should return 80, not %d", port)
30
+	}
31
+	port, err := allocator.Acquire(0)
32
+	if err != nil {
33
+		t.Fatal(err)
34
+	}
35
+	if port <= 0 {
36
+		t.Fatalf("Acquire(0) should return a non-zero port")
37
+	}
38
+	if _, err := allocator.Acquire(port); err == nil {
39
+		t.Fatalf("Acquiring a port already in use should return an error")
40
+	}
41
+	if newPort, err := allocator.Acquire(0); err != nil {
42
+		t.Fatal(err)
43
+	} else if newPort == port {
44
+		t.Fatalf("Acquire(0) allocated the same port twice: %d", port)
45
+	}
46
+	if _, err := allocator.Acquire(80); err == nil {
47
+		t.Fatalf("Acquiring a port already in use should return an error")
48
+	}
49
+	if err := allocator.Release(80); err != nil {
50
+		t.Fatal(err)
51
+	}
52
+	if _, err := allocator.Acquire(80); err != nil {
53
+		t.Fatal(err)
54
+	}
55
+}
56
+
21 57
 func TestNetworkRange(t *testing.T) {
22 58
 	// Simple class C test
23 59
 	_, network, _ := net.ParseCIDR("192.168.0.1/24")