Browse code

Fix potential races in the volume store

Uses finer grained locking so that each volume name gets its own lock
rather than only being protected by the global lock, which itself needs
to be unlocked during certain operations (`create` especially)

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2015/10/20 05:43:56
Showing 4 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,65 @@
0
+Locker
1
+=====
2
+
3
+locker provides a mechanism for creating finer-grained locking to help
4
+free up more global locks to handle other tasks.
5
+
6
+The implementation looks close to a sync.Mutex, however the user must provide a
7
+reference to use to refer to the underlying lock when locking and unlocking,
8
+and unlock may generate an error.
9
+
10
+If a lock with a given name does not exist when `Lock` is called, one is
11
+created.
12
+Lock references are automatically cleaned up on `Unlock` if nothing else is
13
+waiting for the lock.
14
+
15
+
16
+## Usage
17
+
18
+```go
19
+package important
20
+
21
+import (
22
+	"sync"
23
+	"time"
24
+
25
+	"github.com/docker/docker/pkg/locker"
26
+)
27
+
28
+type important struct {
29
+	locks *locker.Locker
30
+	data  map[string]interface{}
31
+	mu    sync.Mutex
32
+}
33
+
34
+func (i *important) Get(name string) interface{} {
35
+	i.locks.Lock(name) // take the per-name lock so work on other names is not blocked
36
+	defer i.locks.Unlock(name)
37
+	return i.data[name] // data is a field of the receiver, not a free variable
38
+}
39
+
40
+func (i *important) Create(name string, data interface{}) {
41
+	i.locks.Lock(name) // serialize all operations on this name only
42
+	defer i.locks.Unlock(name)
43
+
44
+	i.createImportant(data) // slow work runs under the name lock, not the global mutex
45
+
46
+	i.mu.Lock() // the global mutex is held only for the brief map write
47
+	i.data[name] = data
48
+	i.mu.Unlock()
49
+}
50
+
51
+func (i *important) createImportant(data interface{}) {
52
+	time.Sleep(10 * time.Second)
53
+}
54
+```
55
+
56
+For functions dealing with a given name, always lock at the beginning of the
57
+function (or before doing anything with the underlying state); this ensures any
58
+other function that is dealing with the same name will block.
59
+
60
+When needing to modify the underlying data, use the global lock to ensure nothing
61
+else is modifying it at the same time.
62
+Since the name lock is already in place, no reads of that name will occur while the modification
63
+is being performed.
64
+
0 65
new file mode 100644
... ...
@@ -0,0 +1,111 @@
0
+/*
1
+Package locker provides a mechanism for creating finer-grained locking to help
2
+free up more global locks to handle other tasks.
3
+
4
+The implementation looks close to a sync.Mutex, however the user must provide a
5
+reference to use to refer to the underlying lock when locking and unlocking,
6
+and unlock may generate an error.
7
+
8
+If a lock with a given name does not exist when `Lock` is called, one is
9
+created.
10
+Lock references are automatically cleaned up on `Unlock` if nothing else is
11
+waiting for the lock.
12
+*/
13
+package locker
14
+
15
+import (
16
+	"errors"
17
+	"sync"
18
+	"sync/atomic"
19
+)
20
+
21
+// ErrNoSuchLock is returned when the requested lock does not exist
22
+var ErrNoSuchLock = errors.New("no such lock")
23
+
24
+// Locker provides a locking mechanism based on the passed in reference name
25
+type Locker struct {
26
+	mu    sync.Mutex
27
+	locks map[string]*lockCtr
28
+}
29
+
30
+// lockCtr is used by Locker to represent a lock with a given name.
31
+type lockCtr struct {
32
+	mu sync.Mutex
33
+	// waiters is the number of waiters waiting to acquire the lock
34
+	waiters uint32
35
+}
36
+
37
+// inc increments the number of waiters waiting for the lock
38
+func (l *lockCtr) inc() {
39
+	atomic.AddUint32(&l.waiters, 1)
40
+}
41
+
42
+// dec atomically decrements the number of waiters waiting on the lock by exactly one
43
+func (l *lockCtr) dec() {
44
+	atomic.AddUint32(&l.waiters, ^uint32(0)) // adding all-ones wraps around, i.e. subtracts 1 without a racy read of waiters
45
+}
46
+
47
+// count gets the current number of waiters
48
+func (l *lockCtr) count() uint32 {
49
+	return atomic.LoadUint32(&l.waiters)
50
+}
51
+
52
+// Lock locks the mutex
53
+func (l *lockCtr) Lock() {
54
+	l.mu.Lock()
55
+}
56
+
57
+// Unlock unlocks the mutex
58
+func (l *lockCtr) Unlock() {
59
+	l.mu.Unlock()
60
+}
61
+
62
+// New creates a new Locker
63
+func New() *Locker {
64
+	return &Locker{
65
+		locks: make(map[string]*lockCtr),
66
+	}
67
+}
68
+
69
+// Lock locks a mutex with the given name. If it doesn't exist, one is created
70
+func (l *Locker) Lock(name string) {
71
+	l.mu.Lock()
72
+	if l.locks == nil {
73
+		l.locks = make(map[string]*lockCtr)
74
+	}
75
+
76
+	nameLock, exists := l.locks[name]
77
+	if !exists {
78
+		nameLock = &lockCtr{}
79
+		l.locks[name] = nameLock
80
+	}
81
+
82
+	// increment the nameLock waiters while inside the main mutex
83
+	// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
84
+	nameLock.inc()
85
+	l.mu.Unlock()
86
+
87
+	// Lock the nameLock outside the main mutex so we don't block other operations
88
+	// once locked then we can decrement the number of waiters for this lock
89
+	nameLock.Lock()
90
+	nameLock.dec()
91
+}
92
+
93
+// Unlock unlocks the mutex with the given name
94
+// If the given lock is not being waited on by any other callers, it is deleted
95
+func (l *Locker) Unlock(name string) error {
96
+	l.mu.Lock()
97
+	nameLock, exists := l.locks[name]
98
+	if !exists {
99
+		l.mu.Unlock()
100
+		return ErrNoSuchLock
101
+	}
102
+
103
+	if nameLock.count() == 0 {
104
+		delete(l.locks, name)
105
+	}
106
+	nameLock.Unlock()
107
+
108
+	l.mu.Unlock()
109
+	return nil
110
+}
0 111
new file mode 100644
... ...
@@ -0,0 +1,90 @@
0
+package locker
1
+
2
+import (
3
+	"runtime"
4
+	"testing"
5
+)
6
+
7
+func TestLockCounter(t *testing.T) {
8
+	l := &lockCtr{}
9
+	l.inc()
10
+
11
+	if l.waiters != 1 {
12
+		t.Fatal("counter inc failed")
13
+	}
14
+
15
+	l.dec()
16
+	if l.waiters != 0 {
17
+		t.Fatal("counter dec failed")
18
+	}
19
+}
20
+
21
+func TestLockerLock(t *testing.T) {
22
+	l := New()
23
+	l.Lock("test")
24
+	ctr := l.locks["test"]
25
+
26
+	if ctr.count() != 0 {
27
+		t.Fatalf("expected waiters to be 0, got :%d", ctr.waiters)
28
+	}
29
+
30
+	chDone := make(chan struct{})
31
+	go func() {
32
+		l.Lock("test")
33
+		close(chDone)
34
+	}()
35
+
36
+	runtime.Gosched()
37
+
38
+	select {
39
+	case <-chDone:
40
+		t.Fatal("lock should not have returned while it was still held")
41
+	default:
42
+	}
43
+
44
+	if ctr.count() != 1 {
45
+		t.Fatalf("expected waiters to be 1, got: %d", ctr.count())
46
+	}
47
+
48
+	if err := l.Unlock("test"); err != nil {
49
+		t.Fatal(err)
50
+	}
51
+	runtime.Gosched()
52
+
53
+	select {
54
+	case <-chDone:
55
+	default:
56
+		// one more time just to be sure
57
+		runtime.Gosched()
58
+		select {
59
+		case <-chDone:
60
+		default:
61
+			t.Fatalf("lock should have completed")
62
+		}
63
+	}
64
+
65
+	if ctr.count() != 0 {
66
+		t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
67
+	}
68
+}
69
+
70
+func TestLockerUnlock(t *testing.T) {
71
+	l := New()
72
+
73
+	l.Lock("test")
74
+	l.Unlock("test")
75
+
76
+	chDone := make(chan struct{})
77
+	go func() {
78
+		l.Lock("test")
79
+		close(chDone)
80
+	}()
81
+
82
+	runtime.Gosched()
83
+
84
+	select {
85
+	case <-chDone:
86
+	default:
87
+		t.Fatalf("lock should not be blocked")
88
+	}
89
+}
... ...
@@ -5,6 +5,7 @@ import (
5 5
 	"sync"
6 6
 
7 7
 	"github.com/Sirupsen/logrus"
8
+	"github.com/docker/docker/pkg/locker"
8 9
 	"github.com/docker/docker/volume"
9 10
 	"github.com/docker/docker/volume/drivers"
10 11
 )
... ...
@@ -22,14 +23,35 @@ var (
22 22
 // reference counting of volumes in the system.
23 23
 func New() *VolumeStore {
24 24
 	return &VolumeStore{
25
-		vols: make(map[string]*volumeCounter),
25
+		vols:  make(map[string]*volumeCounter),
26
+		locks: &locker.Locker{},
26 27
 	}
27 28
 }
28 29
 
30
+func (s *VolumeStore) get(name string) (*volumeCounter, bool) {
31
+	s.globalLock.Lock()
32
+	vc, exists := s.vols[name]
33
+	s.globalLock.Unlock()
34
+	return vc, exists
35
+}
36
+
37
+func (s *VolumeStore) set(name string, vc *volumeCounter) {
38
+	s.globalLock.Lock()
39
+	s.vols[name] = vc
40
+	s.globalLock.Unlock()
41
+}
42
+
43
+func (s *VolumeStore) remove(name string) {
44
+	s.globalLock.Lock()
45
+	delete(s.vols, name)
46
+	s.globalLock.Unlock()
47
+}
48
+
29 49
 // VolumeStore is a struct that stores the list of volumes available and keeps track of their usage counts
30 50
 type VolumeStore struct {
31
-	vols map[string]*volumeCounter
32
-	mu   sync.Mutex
51
+	vols       map[string]*volumeCounter
52
+	locks      *locker.Locker
53
+	globalLock sync.Mutex
33 54
 }
34 55
 
35 56
 // volumeCounter keeps track of references to a volume
... ...
@@ -47,14 +69,14 @@ func (s *VolumeStore) AddAll(vols []volume.Volume) {
47 47
 
48 48
 // Create tries to find an existing volume with the given name or create a new one from the passed in driver
49 49
 func (s *VolumeStore) Create(name, driverName string, opts map[string]string) (volume.Volume, error) {
50
-	s.mu.Lock()
51 50
 	name = normaliseVolumeName(name)
52
-	if vc, exists := s.vols[name]; exists {
51
+	s.locks.Lock(name)
52
+	defer s.locks.Unlock(name)
53
+
54
+	if vc, exists := s.get(name); exists {
53 55
 		v := vc.Volume
54
-		s.mu.Unlock()
55 56
 		return v, nil
56 57
 	}
57
-	s.mu.Unlock()
58 58
 	logrus.Debugf("Registering new volume reference: driver %s, name %s", driverName, name)
59 59
 
60 60
 	vd, err := volumedrivers.GetDriver(driverName)
... ...
@@ -76,19 +98,17 @@ func (s *VolumeStore) Create(name, driverName string, opts map[string]string) (v
76 76
 		return nil, err
77 77
 	}
78 78
 
79
-	s.mu.Lock()
80
-	s.vols[normaliseVolumeName(v.Name())] = &volumeCounter{v, 0}
81
-	s.mu.Unlock()
82
-
79
+	s.set(name, &volumeCounter{v, 0})
83 80
 	return v, nil
84 81
 }
85 82
 
86 83
 // Get looks if a volume with the given name exists and returns it if so
87 84
 func (s *VolumeStore) Get(name string) (volume.Volume, error) {
88 85
 	name = normaliseVolumeName(name)
89
-	s.mu.Lock()
90
-	defer s.mu.Unlock()
91
-	vc, exists := s.vols[name]
86
+	s.locks.Lock(name)
87
+	defer s.locks.Unlock(name)
88
+
89
+	vc, exists := s.get(name)
92 90
 	if !exists {
93 91
 		return nil, ErrNoSuchVolume
94 92
 	}
... ...
@@ -97,11 +117,12 @@ func (s *VolumeStore) Get(name string) (volume.Volume, error) {
97 97
 
98 98
 // Remove removes the requested volume. A volume is not removed if the usage count is > 0
99 99
 func (s *VolumeStore) Remove(v volume.Volume) error {
100
-	s.mu.Lock()
101
-	defer s.mu.Unlock()
102 100
 	name := normaliseVolumeName(v.Name())
101
+	s.locks.Lock(name)
102
+	defer s.locks.Unlock(name)
103
+
103 104
 	logrus.Debugf("Removing volume reference: driver %s, name %s", v.DriverName(), name)
104
-	vc, exists := s.vols[name]
105
+	vc, exists := s.get(name)
105 106
 	if !exists {
106 107
 		return ErrNoSuchVolume
107 108
 	}
... ...
@@ -117,20 +138,21 @@ func (s *VolumeStore) Remove(v volume.Volume) error {
117 117
 	if err := vd.Remove(vc.Volume); err != nil {
118 118
 		return err
119 119
 	}
120
-	delete(s.vols, name)
120
+
121
+	s.remove(name)
121 122
 	return nil
122 123
 }
123 124
 
124 125
 // Increment increments the usage count of the passed in volume by 1
125 126
 func (s *VolumeStore) Increment(v volume.Volume) {
126
-	s.mu.Lock()
127
-	defer s.mu.Unlock()
128 127
 	name := normaliseVolumeName(v.Name())
129
-	logrus.Debugf("Incrementing volume reference: driver %s, name %s", v.DriverName(), name)
128
+	s.locks.Lock(name)
129
+	defer s.locks.Unlock(name)
130 130
 
131
-	vc, exists := s.vols[name]
131
+	logrus.Debugf("Incrementing volume reference: driver %s, name %s", v.DriverName(), v.Name())
132
+	vc, exists := s.get(name)
132 133
 	if !exists {
133
-		s.vols[name] = &volumeCounter{v, 1}
134
+		s.set(name, &volumeCounter{v, 1})
134 135
 		return
135 136
 	}
136 137
 	vc.count++
... ...
@@ -138,12 +160,12 @@ func (s *VolumeStore) Increment(v volume.Volume) {
138 138
 
139 139
 // Decrement decrements the usage count of the passed in volume by 1
140 140
 func (s *VolumeStore) Decrement(v volume.Volume) {
141
-	s.mu.Lock()
142
-	defer s.mu.Unlock()
143 141
 	name := normaliseVolumeName(v.Name())
144
-	logrus.Debugf("Decrementing volume reference: driver %s, name %s", v.DriverName(), name)
142
+	s.locks.Lock(name)
143
+	defer s.locks.Unlock(name)
144
+	logrus.Debugf("Decrementing volume reference: driver %s, name %s", v.DriverName(), v.Name())
145 145
 
146
-	vc, exists := s.vols[name]
146
+	vc, exists := s.get(name)
147 147
 	if !exists {
148 148
 		return
149 149
 	}
... ...
@@ -155,9 +177,11 @@ func (s *VolumeStore) Decrement(v volume.Volume) {
155 155
 
156 156
 // Count returns the usage count of the passed in volume
157 157
 func (s *VolumeStore) Count(v volume.Volume) uint {
158
-	s.mu.Lock()
159
-	defer s.mu.Unlock()
160
-	vc, exists := s.vols[normaliseVolumeName(v.Name())]
158
+	name := normaliseVolumeName(v.Name())
159
+	s.locks.Lock(name)
160
+	defer s.locks.Unlock(name)
161
+
162
+	vc, exists := s.get(name)
161 163
 	if !exists {
162 164
 		return 0
163 165
 	}
... ...
@@ -166,8 +190,8 @@ func (s *VolumeStore) Count(v volume.Volume) uint {
166 166
 
167 167
 // List returns all the available volumes
168 168
 func (s *VolumeStore) List() []volume.Volume {
169
-	s.mu.Lock()
170
-	defer s.mu.Unlock()
169
+	s.globalLock.Lock()
170
+	defer s.globalLock.Unlock()
171 171
 	var ls []volume.Volume
172 172
 	for _, vc := range s.vols {
173 173
 		ls = append(ls, vc.Volume)
... ...
@@ -192,8 +216,8 @@ func byDriver(name string) filterFunc {
192 192
 
193 193
 // filter returns the available volumes filtered by a filterFunc function
194 194
 func (s *VolumeStore) filter(f filterFunc) []volume.Volume {
195
-	s.mu.Lock()
196
-	defer s.mu.Unlock()
195
+	s.globalLock.Lock()
196
+	defer s.globalLock.Unlock()
197 197
 	var ls []volume.Volume
198 198
 	for _, vc := range s.vols {
199 199
 		if f(vc.Volume) {