Browse code

Fix non thread-safe Iteration around go-patricia

* Unlike other methods in truncindex, Iterate was not locking before
using the Trie, making it potentially race e.g. Delete could result in
setting a child to nil, while Iterate dereferenced that node
while walking the Trie.

Signed-off-by: Petar Petrov <pppepito86@gmail.com>

William Martin authored on 2017/01/06 23:14:10
Showing 2 changed files
... ...
@@ -125,8 +125,13 @@ func (idx *TruncIndex) Get(s string) (string, error) {
125 125
 	return "", ErrNotExist
126 126
 }
127 127
 
128
-// Iterate iterates over all stored IDs, and passes each of them to the given handler.
128
+// Iterate iterates over all stored IDs and passes each of them to the given
129
+// handler. Take care that the handler method does not call any public
130
+// method on truncindex as the internal locking is not reentrant/recursive
131
+// and will result in deadlock.
129 132
 func (idx *TruncIndex) Iterate(handler func(id string)) {
133
+	idx.Lock()
134
+	defer idx.Unlock()
130 135
 	idx.trie.Visit(func(prefix patricia.Prefix, item patricia.Item) error {
131 136
 		handler(string(prefix))
132 137
 		return nil
... ...
@@ -3,6 +3,7 @@ package truncindex
3 3
 import (
4 4
 	"math/rand"
5 5
 	"testing"
6
+	"time"
6 7
 
7 8
 	"github.com/docker/docker/pkg/stringid"
8 9
 )
... ...
@@ -98,6 +99,7 @@ func TestTruncIndex(t *testing.T) {
98 98
 	assertIndexGet(t, index, id, id, false)
99 99
 
100 100
 	assertIndexIterate(t)
101
+	assertIndexIterateDoNotPanic(t)
101 102
 }
102 103
 
103 104
 func assertIndexIterate(t *testing.T) {
... ...
@@ -121,6 +123,28 @@ func assertIndexIterate(t *testing.T) {
121 121
 	})
122 122
 }
123 123
 
124
+func assertIndexIterateDoNotPanic(t *testing.T) {
125
+	ids := []string{
126
+		"19b36c2c326ccc11e726eee6ee78a0baf166ef96",
127
+		"28b36c2c326ccc11e726eee6ee78a0baf166ef96",
128
+	}
129
+
130
+	index := NewTruncIndex(ids)
131
+	iterationStarted := make(chan bool, 1)
132
+
133
+	go func() {
134
+		<-iterationStarted
135
+		index.Delete("19b36c2c326ccc11e726eee6ee78a0baf166ef96")
136
+	}()
137
+
138
+	index.Iterate(func(targetId string) {
139
+		if targetId == "19b36c2c326ccc11e726eee6ee78a0baf166ef96" {
140
+			iterationStarted <- true
141
+			time.Sleep(100 * time.Millisecond)
142
+		}
143
+	})
144
+}
145
+
124 146
 func assertIndexGet(t *testing.T, index *TruncIndex, input, expectedResult string, expectError bool) {
125 147
 	if result, err := index.Get(input); err != nil && !expectError {
126 148
 		t.Fatalf("Unexpected error getting '%s': %s", input, err)