Browse code

Create UID and MCS category allocators

Clayton Coleman authored on 2015/05/28 05:32:03
Showing 8 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,229 @@
0
+package mcs
1
+
2
+import (
3
+	"bytes"
4
+	"fmt"
5
+	"sort"
6
+	"strconv"
7
+	"strings"
8
+)
9
+
10
+const maxCategories = 1024
11
+
12
+type Label struct {
13
+	Prefix string
14
+	Categories
15
+}
16
+
17
+// NewLabel creates a Label object based on the offset given by
18
+// offset with a number of labels equal to k. Prefix may be any
19
+// valid SELinux label (user:role:type:level:).
20
+func NewLabel(prefix string, offset uint64, k uint) (*Label, error) {
21
+	if len(prefix) > 0 && !(strings.HasSuffix(prefix, ":") || strings.HasSuffix(prefix, ",")) {
22
+		prefix = prefix + ":"
23
+	}
24
+	return &Label{
25
+		Prefix:     prefix,
26
+		Categories: categoriesForOffset(offset, maxCategories, k),
27
+	}, nil
28
+}
29
+
30
+// ParseLabel converts a string value representing an SELinux label
31
+// into a Label object, extracting and ordering categories.
32
+func ParseLabel(in string) (*Label, error) {
33
+	if len(in) == 0 {
34
+		return &Label{}, nil
35
+	}
36
+
37
+	prefix := strings.Split(in, ":")
38
+	segment := prefix[len(prefix)-1]
39
+	if len(prefix) > 0 {
40
+		prefix = prefix[:len(prefix)-1]
41
+	}
42
+	prefixString := strings.Join(prefix, ":")
43
+	if len(prefixString) > 0 {
44
+		prefixString += ":"
45
+	}
46
+
47
+	var categories Categories
48
+	for _, s := range strings.Split(segment, ",") {
49
+		if !strings.HasPrefix(s, "c") {
50
+			return nil, fmt.Errorf("categories must start with 'c': %s", segment)
51
+		}
52
+		i, err := strconv.Atoi(s[1:])
53
+		if err != nil {
54
+			return nil, err
55
+		}
56
+		categories = append(categories, uint16(i))
57
+	}
58
+	sort.Sort(categories)
59
+
60
+	last := -1
61
+	for _, c := range categories {
62
+		if int(c) == last {
63
+			return nil, fmt.Errorf("labels may not contain duplicate categories: %s", in)
64
+		}
65
+		last = int(c)
66
+	}
67
+
68
+	return &Label{
69
+		Prefix:     prefixString,
70
+		Categories: categories,
71
+	}, nil
72
+}
73
+
74
+func (labels *Label) String() string {
75
+	buf := bytes.Buffer{}
76
+	buf.WriteString(labels.Prefix)
77
+	for i, label := range labels.Categories {
78
+		if i != 0 {
79
+			buf.WriteRune(',')
80
+		}
81
+		buf.WriteRune('c')
82
+		buf.WriteString(strconv.Itoa(int(label)))
83
+	}
84
+	return buf.String()
85
+}
86
+
87
+// Offset returns the rank of the provided categories in the
88
+// co-lex rank operation (k is implicit)
89
+func (categories Categories) Offset() uint64 {
90
+	k := len(categories)
91
+	r := uint64(0)
92
+	for i := 0; i < k; i++ {
93
+		r += binomial(uint(categories[i]), uint(k-i))
94
+	}
95
+	return r
96
+}
97
+
98
+// categoriesForOffset calculates the co-lex unrank operation
99
+// on the combinatorial group defined by n, k, where rank is
100
+// the offset. n is typically 1024 (the SELinux max)
101
+func categoriesForOffset(offset uint64, n, k uint) Categories {
102
+	var categories Categories
103
+	for i := uint(0); i < k; i++ {
104
+		current := binomial(n, k-i)
105
+		for current > offset {
106
+			n--
107
+			current = binomial(n, k-i)
108
+		}
109
+		categories = append(categories, uint16(n))
110
+		offset = offset - current
111
+	}
112
+	sort.Sort(categories)
113
+	return categories
114
+}
115
+
116
+type Categories []uint16
117
+
118
+func (c Categories) Len() int      { return len(c) }
119
+func (c Categories) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
120
+func (c Categories) Less(i, j int) bool {
121
+	return c[i] > c[j]
122
+}
123
+
124
+func binomial(n, k uint) uint64 {
125
+	if n < k {
126
+		return 0
127
+	}
128
+	if k == n {
129
+		return 1
130
+	}
131
+	r := uint64(1)
132
+	for d := uint(1); d <= k; d++ {
133
+		r *= uint64(n)
134
+		r /= uint64(d)
135
+		n--
136
+	}
137
+	return r
138
+}
139
+
140
+type Range struct {
141
+	prefix string
142
+	n      uint
143
+	k      uint
144
+}
145
+
146
+// NewRange describes an SELinux category range, where prefix may include
147
+// the user, type, role, and level of the range, and n and k represent the
148
+// highest category c0 to c(N-1) and k represents the number of labels to use.
149
+// A range can be used to check whether a given label matches the range.
150
+func NewRange(prefix string, n, k uint) (*Range, error) {
151
+	if n == 0 {
152
+		return nil, fmt.Errorf("label max value must be a positive integer")
153
+	}
154
+	if k == 0 {
155
+		return nil, fmt.Errorf("label length must be a positive integer")
156
+	}
157
+	return &Range{
158
+		prefix: prefix,
159
+		n:      n,
160
+		k:      k,
161
+	}, nil
162
+}
163
+
164
+func ParseRange(in string) (*Range, error) {
165
+	seg := strings.SplitN(in, "/", 2)
166
+	if len(seg) != 2 {
167
+		return nil, fmt.Errorf("range not in the format \"<prefix>/<numLabel>[,<maxCategory>]\"")
168
+	}
169
+	prefix := seg[0]
170
+	n := maxCategories
171
+	size := strings.SplitN(seg[1], ",", 2)
172
+	k, err := strconv.Atoi(size[0])
173
+	if err != nil {
174
+		return nil, fmt.Errorf("range not in the format \"<prefix>/<numLabel>[,<maxCategory>]\"")
175
+	}
176
+	if len(size) > 1 {
177
+		max, err := strconv.Atoi(size[1])
178
+		if err != nil {
179
+			return nil, fmt.Errorf("range not in the format \"<prefix>/<numLabel>[,<maxCategory>]\"")
180
+		}
181
+		n = max
182
+	}
183
+	if k > 5 {
184
+		return nil, fmt.Errorf("range may not include more than 5 labels")
185
+	}
186
+	if n > maxCategories {
187
+		return nil, fmt.Errorf("range may not include more than %d categories", maxCategories)
188
+	}
189
+	return NewRange(prefix, uint(n), uint(k))
190
+}
191
+
192
+func (r *Range) Size() uint64 {
193
+	return binomial(r.n, uint(r.k))
194
+}
195
+
196
+func (r *Range) String() string {
197
+	if r.n == maxCategories {
198
+		return fmt.Sprintf("%s/%d", r.prefix, r.k)
199
+	}
200
+	return fmt.Sprintf("%s/%d,%d", r.prefix, r.k, r.n)
201
+}
202
+
203
+func (r *Range) LabelAt(offset uint64) (*Label, bool) {
204
+	label, err := NewLabel(r.prefix, offset, r.k)
205
+	return label, err == nil
206
+}
207
+
208
+func (r *Range) Contains(label *Label) bool {
209
+	if label.Prefix != r.prefix {
210
+		return false
211
+	}
212
+	if len(label.Categories) != int(r.k) {
213
+		return false
214
+	}
215
+	for _, i := range label.Categories {
216
+		if i >= uint16(r.n) {
217
+			return false
218
+		}
219
+	}
220
+	return true
221
+}
222
+
223
+func (r *Range) Offset(label *Label) (bool, uint64) {
224
+	if !r.Contains(label) {
225
+		return false, 0
226
+	}
227
+	return true, label.Offset()
228
+}
0 229
new file mode 100644
... ...
@@ -0,0 +1,188 @@
0
+package mcs
1
+
2
+import (
3
+	"reflect"
4
+	"strings"
5
+	"testing"
6
+)
7
+
8
+type rangeTest struct {
9
+	label string
10
+	in    bool
11
+}
12
+
13
+func TestParseRange(t *testing.T) {
14
+	testCases := map[string]struct {
15
+		in    string
16
+		errFn func(error) bool
17
+		r     Range
18
+		total uint64
19
+		tests []rangeTest
20
+	}{
21
+		"identity range": {
22
+			in: "test,s0/1",
23
+			r: Range{
24
+				prefix: "test,s0",
25
+				n:      1024,
26
+				k:      1,
27
+			},
28
+			total: 1024,
29
+		},
30
+		"simple range": {
31
+			in: "s0:/2",
32
+			r: Range{
33
+				prefix: "s0:",
34
+				n:      1024,
35
+				k:      2,
36
+			},
37
+			total: 523776,
38
+			tests: []rangeTest{
39
+				{label: "c100,c3", in: false},
40
+				{label: "s0:c100,c3", in: true},
41
+				{label: "s0:c100,c3,c0", in: false},
42
+				{label: "s0:c3", in: false},
43
+				{label: "s0:c1024,c0", in: false},
44
+			},
45
+		},
46
+		"limited range with full prefix": {
47
+			in: "systemd_u:systemd_t:cupsd_t:s0:/2,10",
48
+			r: Range{
49
+				prefix: "systemd_u:systemd_t:cupsd_t:s0:",
50
+				n:      10,
51
+				k:      2,
52
+			},
53
+			total: 45,
54
+			tests: []rangeTest{
55
+				{label: "systemd_u:systemd_t:cupsd_t:s0:c100,c3", in: false},
56
+				{label: "systemd_u:systemd_t:cupsd_t:s0:c9,c8", in: true},
57
+			},
58
+		},
59
+		"NaN": {
60
+			in:    "/a",
61
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "range not in the format") },
62
+		},
63
+		"zero k": {
64
+			in:    "/0",
65
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "label length must be a positive integer") },
66
+		},
67
+	}
68
+
69
+	for s, testCase := range testCases {
70
+		r, err := ParseRange(testCase.in)
71
+		if testCase.errFn != nil && !testCase.errFn(err) {
72
+			t.Errorf("%s: unexpected error: %v", s, err)
73
+			continue
74
+		}
75
+		if err != nil {
76
+			continue
77
+		}
78
+		if r.String() != testCase.in {
79
+			t.Errorf("%s: range.String() did not match input: %s", r.String())
80
+		}
81
+		if *r != testCase.r {
82
+			t.Errorf("%s: unexpected range: %#v", s, r)
83
+		}
84
+		if r.Size() != testCase.total {
85
+			t.Errorf("%s: unexpected total: %d", s, r.Size())
86
+		}
87
+		for _, test := range testCase.tests {
88
+			l, err := ParseLabel(test.label)
89
+			if err != nil {
90
+				t.Fatal(err)
91
+			}
92
+			if r.Contains(l) != test.in {
93
+				t.Errorf("%s: range contains(%s) != %t", s, l, !test.in)
94
+			}
95
+		}
96
+	}
97
+}
98
+
99
+func TestLabel(t *testing.T) {
100
+	if _, err := ParseLabel("s0:c9,c9"); err == nil {
101
+		t.Errorf("unexpected non-error")
102
+	}
103
+	if _, err := ParseLabel("s0:ca,cb"); err == nil {
104
+		t.Errorf("unexpected non-error")
105
+	}
106
+
107
+	testCases := map[string]struct {
108
+		in     string
109
+		expect Categories
110
+		offset uint64
111
+		out    string
112
+	}{
113
+		"identity range": {
114
+			in:     "c0,c1",
115
+			expect: Categories{1, 0},
116
+			offset: 0,
117
+			out:    "c1,c0",
118
+		},
119
+		"order doesn't matter": {
120
+			in:     "c1,c0",
121
+			expect: Categories{1, 0},
122
+			offset: 0,
123
+			out:    "c1,c0",
124
+		},
125
+		"second": {
126
+			in:     "c2,c0",
127
+			expect: Categories{2, 0},
128
+			offset: 1,
129
+			out:    "c2,c0",
130
+		},
131
+		"single": {
132
+			in:     "c3",
133
+			expect: Categories{3},
134
+			offset: 3,
135
+			out:    "c3",
136
+		},
137
+		"third": {
138
+			in:     "c3,c0",
139
+			expect: Categories{3, 0},
140
+			offset: 3,
141
+			out:    "c3,c0",
142
+		},
143
+		"three labels": {
144
+			in:     "c3,c1,c0",
145
+			expect: Categories{3, 1, 0},
146
+			offset: 1,
147
+			out:    "c3,c1,c0",
148
+		},
149
+		"three labels, second": {
150
+			in:     "s0:c10,c0,c2",
151
+			expect: Categories{10, 2, 0},
152
+			offset: 121,
153
+			out:    "s0:c10,c2,c0",
154
+		},
155
+		"very large": {
156
+			in:     "c1021,c1020",
157
+			expect: Categories{1021, 1020},
158
+			offset: 521730,
159
+			out:    "c1021,c1020",
160
+		},
161
+	}
162
+
163
+	for s, testCase := range testCases {
164
+		labels, err := ParseLabel(testCase.in)
165
+		if err != nil {
166
+			t.Errorf("%s: failed to parse labels: %v", err)
167
+			continue
168
+		}
169
+		if !reflect.DeepEqual(labels.Categories, testCase.expect) {
170
+			t.Errorf("%s: unexpected categories: %v %v", s, labels.Categories, testCase.expect)
171
+			continue
172
+		}
173
+		if testCase.out != labels.String() {
174
+			t.Errorf("%s: unexpected string: %s", s, labels.String())
175
+			continue
176
+		}
177
+		v := labels.Offset()
178
+		if v != testCase.offset {
179
+			t.Errorf("%s: unexpected offset: %d", s, v)
180
+			continue
181
+		}
182
+		categories := categoriesForOffset(testCase.offset, maxCategories, uint(len(testCase.expect)))
183
+		if !reflect.DeepEqual(categories, labels.Categories) {
184
+			t.Errorf("%s: could not roundtrip categories %v: %v", s, labels.Categories, categories)
185
+		}
186
+	}
187
+}
0 188
new file mode 100644
... ...
@@ -0,0 +1,139 @@
0
+package mcsallocator
1
+
2
+import (
3
+	"errors"
4
+	"fmt"
5
+
6
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
7
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/registry/service/allocator"
8
+
9
+	"github.com/openshift/origin/pkg/security/mcs"
10
+)
11
+
12
+// Interface manages the allocation of ports out of a range. Interface
13
+// should be threadsafe.
14
+type Interface interface {
15
+	Allocate(*mcs.Label) error
16
+	AllocateNext() (*mcs.Label, error)
17
+	Release(*mcs.Label) error
18
+}
19
+
20
+var (
21
+	ErrFull            = errors.New("range is full")
22
+	ErrNotInRange      = errors.New("provided label is not in the valid range")
23
+	ErrAllocated       = errors.New("provided label is already allocated")
24
+	ErrMismatchedRange = errors.New("the provided label does not match the current label range")
25
+)
26
+
27
+type Allocator struct {
28
+	r     *mcs.Range
29
+	alloc allocator.Interface
30
+}
31
+
32
+// Allocator implements Interface and Snapshottable
33
+var _ Interface = &Allocator{}
34
+
35
+// New creates a Allocator over a UID range, calling factory to construct the backing store.
36
+func New(r *mcs.Range, factory allocator.AllocatorFactory) *Allocator {
37
+	return &Allocator{
38
+		r:     r,
39
+		alloc: factory(int(r.Size()), r.String()),
40
+	}
41
+}
42
+
43
+// Free returns the count of port left in the range.
44
+func (r *Allocator) Free() int {
45
+	return r.alloc.Free()
46
+}
47
+
48
+// Allocate attempts to reserve the provided label. ErrNotInRange or
49
+// ErrAllocated will be returned if the label is not valid for this range
50
+// or has already been reserved.  ErrFull will be returned if there
51
+// are no labels left.
52
+func (r *Allocator) Allocate(label *mcs.Label) error {
53
+	ok, offset := r.contains(label)
54
+	if !ok {
55
+		return ErrNotInRange
56
+	}
57
+
58
+	allocated, err := r.alloc.Allocate(int(offset))
59
+	if err != nil {
60
+		return err
61
+	}
62
+	if !allocated {
63
+		return ErrAllocated
64
+	}
65
+	return nil
66
+}
67
+
68
+// AllocateNext reserves one of the labels from the pool. ErrFull may
69
+// be returned if there are no labels left.
70
+func (r *Allocator) AllocateNext() (*mcs.Label, error) {
71
+	offset, ok, err := r.alloc.AllocateNext()
72
+	if err != nil {
73
+		return nil, err
74
+	}
75
+	if !ok {
76
+		return nil, ErrFull
77
+	}
78
+	label, ok := r.r.LabelAt(uint64(offset))
79
+	if !ok {
80
+		return nil, ErrNotInRange
81
+	}
82
+	return label, nil
83
+}
84
+
85
+// Release releases the port back to the pool. Releasing an
86
+// unallocated port or a port out of the range is a no-op and
87
+// returns no error.
88
+func (r *Allocator) Release(label *mcs.Label) error {
89
+	ok, offset := r.contains(label)
90
+	if !ok {
91
+		// TODO: log a warning
92
+		return nil
93
+	}
94
+
95
+	return r.alloc.Release(int(offset))
96
+}
97
+
98
+// Has returns true if the provided port is already allocated and a call
99
+// to Allocate(label) would fail with ErrAllocated.
100
+func (r *Allocator) Has(label *mcs.Label) bool {
101
+	ok, offset := r.contains(label)
102
+	if !ok {
103
+		return false
104
+	}
105
+
106
+	return r.alloc.Has(int(offset))
107
+}
108
+
109
+// Snapshot saves the current state of the pool.
110
+func (r *Allocator) Snapshot(dst *api.RangeAllocation) error {
111
+	snapshottable, ok := r.alloc.(allocator.Snapshottable)
112
+	if !ok {
113
+		return fmt.Errorf("not a snapshottable allocator")
114
+	}
115
+	rangeString, data := snapshottable.Snapshot()
116
+	dst.Range = rangeString
117
+	dst.Data = data
118
+	return nil
119
+}
120
+
121
+// Restore restores the pool to the previously captured state. ErrMismatchedNetwork
122
+// is returned if the provided port range doesn't exactly match the previous range.
123
+func (r *Allocator) Restore(into *mcs.Range, data []byte) error {
124
+	if into.String() != r.r.String() {
125
+		return ErrMismatchedRange
126
+	}
127
+	snapshottable, ok := r.alloc.(allocator.Snapshottable)
128
+	if !ok {
129
+		return fmt.Errorf("not a snapshottable allocator")
130
+	}
131
+	return snapshottable.Restore(into.String(), data)
132
+}
133
+
134
+// contains returns true and the offset if the label is in the range (and aligned), and false
135
+// and nil otherwise.
136
+func (r *Allocator) contains(label *mcs.Label) (bool, uint64) {
137
+	return r.r.Offset(label)
138
+}
0 139
new file mode 100644
... ...
@@ -0,0 +1,70 @@
0
+package mcsallocator
1
+
2
+import (
3
+	"reflect"
4
+	"testing"
5
+
6
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/registry/service/allocator"
7
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
8
+
9
+	"github.com/openshift/origin/pkg/security/mcs"
10
+)
11
+
12
+func TestAllocate(t *testing.T) {
13
+	ranger, _ := mcs.NewRange("s0:", 5, 2)
14
+	r := New(ranger, allocator.NewContiguousAllocationInterface)
15
+	if f := r.Free(); f != 10 {
16
+		t.Errorf("unexpected free %d", f)
17
+	}
18
+	found := util.NewStringSet()
19
+	count := 0
20
+	for r.Free() > 0 {
21
+		label, err := r.AllocateNext()
22
+		if err != nil {
23
+			t.Fatalf("error @ %d: %v", count, err)
24
+		}
25
+		count++
26
+		if !ranger.Contains(label) {
27
+			t.Fatalf("allocated %s which is outside of %s", label, ranger)
28
+		}
29
+		if found.Has(label.String()) {
30
+			t.Fatalf("allocated %s twice @ %d", label, count)
31
+		}
32
+		found.Insert(label.String())
33
+	}
34
+	if _, err := r.AllocateNext(); err != ErrFull {
35
+		t.Fatal(err)
36
+	}
37
+
38
+	released, _ := ranger.LabelAt(3)
39
+	if err := r.Release(released); err != nil {
40
+		t.Fatal(err)
41
+	}
42
+	if f := r.Free(); f != 1 {
43
+		t.Errorf("unexpected free %d", f)
44
+	}
45
+	label, err := r.AllocateNext()
46
+	if err != nil {
47
+		t.Fatal(err)
48
+	}
49
+	if !reflect.DeepEqual(released, label) {
50
+		t.Errorf("unexpected %s : %s", label, released)
51
+	}
52
+
53
+	if err := r.Release(released); err != nil {
54
+		t.Fatal(err)
55
+	}
56
+	badLabel, _ := ranger.LabelAt(30)
57
+	if err := r.Allocate(badLabel); err != ErrNotInRange {
58
+		t.Fatal(err)
59
+	}
60
+	if f := r.Free(); f != 1 {
61
+		t.Errorf("unexpected free %d", f)
62
+	}
63
+	if err := r.Allocate(released); err != nil {
64
+		t.Fatal(err)
65
+	}
66
+	if f := r.Free(); f != 0 {
67
+		t.Errorf("unexpected free %d", f)
68
+	}
69
+}
0 70
new file mode 100644
... ...
@@ -0,0 +1,91 @@
0
+package uid
1
+
2
+import (
3
+	"fmt"
4
+)
5
+
6
+type Block struct {
7
+	Start uint32
8
+	End   uint32
9
+}
10
+
11
+func (b Block) String() string {
12
+	return fmt.Sprintf("%d-%d", b.Start, b.End)
13
+}
14
+
15
+func (b Block) Size() uint32 {
16
+	return b.End - b.Start + 1
17
+}
18
+
19
+type Range struct {
20
+	block Block
21
+	size  uint32
22
+}
23
+
24
+func NewRange(start, end, size uint32) (*Range, error) {
25
+	if start > end {
26
+		return nil, fmt.Errorf("start %d must be less than end %d", start, end)
27
+	}
28
+	if size == 0 {
29
+		return nil, fmt.Errorf("block size must be a positive integer")
30
+	}
31
+	if (end - start) < size {
32
+		return nil, fmt.Errorf("block size must be less than or equal to the range")
33
+	}
34
+	return &Range{
35
+		block: Block{start, end},
36
+		size:  size,
37
+	}, nil
38
+}
39
+
40
+func ParseRange(in string) (*Range, error) {
41
+	var start, end, block uint32
42
+	n, err := fmt.Sscanf(in, "%d-%d/%d", &start, &end, &block)
43
+	if err != nil {
44
+		return nil, err
45
+	}
46
+	if n != 3 {
47
+		return nil, fmt.Errorf("range not in the format \"<start>-<end>/<blockSize>\"")
48
+	}
49
+	return NewRange(start, end, block)
50
+}
51
+
52
+func (r *Range) Size() uint32 {
53
+	return r.block.Size() / r.size
54
+}
55
+
56
+func (r *Range) String() string {
57
+	return fmt.Sprintf("%s/%d", r.block, r.size)
58
+}
59
+
60
+func (r *Range) BlockAt(offset uint32) (Block, bool) {
61
+	if offset > r.Size() {
62
+		return Block{}, false
63
+	}
64
+	start := r.block.Start + offset*r.size
65
+	return Block{
66
+		Start: start,
67
+		End:   start + r.size - 1,
68
+	}, true
69
+}
70
+
71
+func (r *Range) Contains(block Block) bool {
72
+	ok, _ := r.Offset(block)
73
+	return ok
74
+}
75
+
76
+func (r *Range) Offset(block Block) (bool, uint32) {
77
+	if block.Start < r.block.Start {
78
+		return false, 0
79
+	}
80
+	if block.End > r.block.End {
81
+		return false, 0
82
+	}
83
+	if block.End-block.Start+1 != r.size {
84
+		return false, 0
85
+	}
86
+	if (block.Start-r.block.Start)%r.size != 0 {
87
+		return false, 0
88
+	}
89
+	return true, (block.Start - r.block.Start) / r.size
90
+}
0 91
new file mode 100644
... ...
@@ -0,0 +1,143 @@
0
+package uid
1
+
2
+import (
3
+	"strings"
4
+	"testing"
5
+)
6
+
7
+func TestParseRange(t *testing.T) {
8
+	testCases := map[string]struct {
9
+		in    string
10
+		errFn func(error) bool
11
+		r     Range
12
+		total uint32
13
+	}{
14
+		"identity range": {
15
+			in: "1-1/1",
16
+			r: Range{
17
+				block: Block{1, 1},
18
+				size:  1,
19
+			},
20
+			total: 1,
21
+		},
22
+		"simple range": {
23
+			in: "1-2/1",
24
+			r: Range{
25
+				block: Block{1, 2},
26
+				size:  1,
27
+			},
28
+			total: 2,
29
+		},
30
+		"wide range": {
31
+			in: "10000-999999/1000",
32
+			r: Range{
33
+				block: Block{10000, 999999},
34
+				size:  1000,
35
+			},
36
+			total: 990,
37
+		},
38
+		"overflow uint": {
39
+			in:    "1000-100000000000000/1",
40
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "unsigned integer overflow") },
41
+		},
42
+		"negative range": {
43
+			in:    "1000-999/1",
44
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "must be less than end 999") },
45
+		},
46
+		"zero block size": {
47
+			in:    "1000-1000/0",
48
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "block size must be a positive integer") },
49
+		},
50
+		"large block size": {
51
+			in:    "1000-1001/3",
52
+			errFn: func(err error) bool { return strings.Contains(err.Error(), "must be less than or equal to the range") },
53
+		},
54
+	}
55
+
56
+	for s, testCase := range testCases {
57
+		r, err := ParseRange(testCase.in)
58
+		if testCase.errFn != nil && !testCase.errFn(err) {
59
+			t.Errorf("%s: unexpected error: %v", s, err)
60
+			continue
61
+		}
62
+		if err != nil {
63
+			continue
64
+		}
65
+		if r.block.Start != testCase.r.block.Start || r.block.End != testCase.r.block.End || r.size != testCase.r.size {
66
+			t.Errorf("%s: unexpected range: %#v", s, r)
67
+		}
68
+		if r.Size() != testCase.total {
69
+			t.Errorf("%s: unexpected total: %d", s, r.Size())
70
+		}
71
+	}
72
+}
73
+
74
+func TestOffset(t *testing.T) {
75
+	testCases := map[string]struct {
76
+		r         Range
77
+		block     Block
78
+		contained bool
79
+		offset    uint32
80
+	}{
81
+		"identity range": {
82
+			r: Range{
83
+				block: Block{1, 1},
84
+				size:  1,
85
+			},
86
+			block:     Block{1, 1},
87
+			contained: true,
88
+		},
89
+		"out of identity range": {
90
+			r: Range{
91
+				block: Block{1, 1},
92
+				size:  1,
93
+			},
94
+			block: Block{2, 2},
95
+		},
96
+		"out of identity range expanded": {
97
+			r: Range{
98
+				block: Block{1, 1},
99
+				size:  1,
100
+			},
101
+			block: Block{2, 3},
102
+		},
103
+		"aligned to offset": {
104
+			r: Range{
105
+				block: Block{0, 100},
106
+				size:  10,
107
+			},
108
+			block:     Block{10, 19},
109
+			contained: true,
110
+			offset:    1,
111
+		},
112
+		"not aligned": {
113
+			r: Range{
114
+				block: Block{0, 100},
115
+				size:  10,
116
+			},
117
+			block: Block{11, 20},
118
+		},
119
+	}
120
+
121
+	for s, testCase := range testCases {
122
+		contained, offset := testCase.r.Offset(testCase.block)
123
+		if contained != testCase.contained {
124
+			t.Errorf("%s: unexpected contained: %t", s, contained)
125
+			continue
126
+		}
127
+		if offset != testCase.offset {
128
+			t.Errorf("%s: unexpected offset: %d", s, offset)
129
+			continue
130
+		}
131
+		if contained {
132
+			block, ok := testCase.r.BlockAt(offset)
133
+			if !ok {
134
+				t.Errorf("%s: should find block", s)
135
+				continue
136
+			}
137
+			if block != testCase.block {
138
+				t.Errorf("%s: blocks are not equivalent: %#v", s, block)
139
+			}
140
+		}
141
+	}
142
+}
0 143
new file mode 100644
... ...
@@ -0,0 +1,139 @@
0
+package uidallocator
1
+
2
+import (
3
+	"errors"
4
+	"fmt"
5
+
6
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
7
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/registry/service/allocator"
8
+
9
+	"github.com/openshift/origin/pkg/security/uid"
10
+)
11
+
12
+// Interface manages the allocation of ports out of a range. Interface
13
+// should be threadsafe.
14
+type Interface interface {
15
+	Allocate(uid.Block) error
16
+	AllocateNext() (uid.Block, error)
17
+	Release(uid.Block) error
18
+}
19
+
20
+var (
21
+	ErrFull            = errors.New("range is full")
22
+	ErrNotInRange      = errors.New("provided UID range is not in the valid range")
23
+	ErrAllocated       = errors.New("provided UID range is already allocated")
24
+	ErrMismatchedRange = errors.New("the provided UID range does not match the current UID range")
25
+)
26
+
27
+type Allocator struct {
28
+	r     *uid.Range
29
+	alloc allocator.Interface
30
+}
31
+
32
+// Allocator implements Interface and Snapshottable
33
+var _ Interface = &Allocator{}
34
+
35
+// New creates a Allocator over a UID range, calling factory to construct the backing store.
36
+func New(r *uid.Range, factory allocator.AllocatorFactory) *Allocator {
37
+	return &Allocator{
38
+		r:     r,
39
+		alloc: factory(int(r.Size()), r.String()),
40
+	}
41
+}
42
+
43
+// Free returns the count of port left in the range.
44
+func (r *Allocator) Free() int {
45
+	return r.alloc.Free()
46
+}
47
+
48
+// Allocate attempts to reserve the provided block. ErrNotInRange or
49
+// ErrAllocated will be returned if the block is not valid for this range
50
+// or has already been reserved.  ErrFull will be returned if there
51
+// are no blocks left.
52
+func (r *Allocator) Allocate(block uid.Block) error {
53
+	ok, offset := r.contains(block)
54
+	if !ok {
55
+		return ErrNotInRange
56
+	}
57
+
58
+	allocated, err := r.alloc.Allocate(int(offset))
59
+	if err != nil {
60
+		return err
61
+	}
62
+	if !allocated {
63
+		return ErrAllocated
64
+	}
65
+	return nil
66
+}
67
+
68
+// AllocateNext reserves one of the ports from the pool. ErrFull may
69
+// be returned if there are no ports left.
70
+func (r *Allocator) AllocateNext() (uid.Block, error) {
71
+	offset, ok, err := r.alloc.AllocateNext()
72
+	if err != nil {
73
+		return uid.Block{}, err
74
+	}
75
+	if !ok {
76
+		return uid.Block{}, ErrFull
77
+	}
78
+	block, ok := r.r.BlockAt(uint32(offset))
79
+	if !ok {
80
+		return uid.Block{}, ErrNotInRange
81
+	}
82
+	return block, nil
83
+}
84
+
85
+// Release releases the port back to the pool. Releasing an
86
+// unallocated port or a port out of the range is a no-op and
87
+// returns no error.
88
+func (r *Allocator) Release(block uid.Block) error {
89
+	ok, offset := r.contains(block)
90
+	if !ok {
91
+		// TODO: log a warning
92
+		return nil
93
+	}
94
+
95
+	return r.alloc.Release(int(offset))
96
+}
97
+
98
+// Has returns true if the provided port is already allocated and a call
99
+// to Allocate(block) would fail with ErrAllocated.
100
+func (r *Allocator) Has(block uid.Block) bool {
101
+	ok, offset := r.contains(block)
102
+	if !ok {
103
+		return false
104
+	}
105
+
106
+	return r.alloc.Has(int(offset))
107
+}
108
+
109
+// Snapshot saves the current state of the pool.
110
+func (r *Allocator) Snapshot(dst *api.RangeAllocation) error {
111
+	snapshottable, ok := r.alloc.(allocator.Snapshottable)
112
+	if !ok {
113
+		return fmt.Errorf("not a snapshottable allocator")
114
+	}
115
+	rangeString, data := snapshottable.Snapshot()
116
+	dst.Range = rangeString
117
+	dst.Data = data
118
+	return nil
119
+}
120
+
121
+// Restore restores the pool to the previously captured state. ErrMismatchedNetwork
122
+// is returned if the provided port range doesn't exactly match the previous range.
123
+func (r *Allocator) Restore(into *uid.Range, data []byte) error {
124
+	if into.String() != r.r.String() {
125
+		return ErrMismatchedRange
126
+	}
127
+	snapshottable, ok := r.alloc.(allocator.Snapshottable)
128
+	if !ok {
129
+		return fmt.Errorf("not a snapshottable allocator")
130
+	}
131
+	return snapshottable.Restore(into.String(), data)
132
+}
133
+
134
+// contains returns true and the offset if the block is in the range (and aligned), and false
135
+// and nil otherwise.
136
+func (r *Allocator) contains(block uid.Block) (bool, uint32) {
137
+	return r.r.Offset(block)
138
+}
0 139
new file mode 100644
... ...
@@ -0,0 +1,71 @@
0
+package uidallocator
1
+
2
+import (
3
+	"testing"
4
+
5
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/registry/service/allocator"
6
+	"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
7
+
8
+	"github.com/openshift/origin/pkg/security/uid"
9
+)
10
+
11
+func TestAllocate(t *testing.T) {
12
+	ranger, _ := uid.NewRange(0, 9, 2)
13
+	r := New(ranger, allocator.NewContiguousAllocationInterface)
14
+	if f := r.Free(); f != 5 {
15
+		t.Errorf("unexpected free %d", f)
16
+	}
17
+	found := util.NewStringSet()
18
+	count := 0
19
+	for r.Free() > 0 {
20
+		block, err := r.AllocateNext()
21
+		if err != nil {
22
+			t.Fatalf("error @ %d: %v", count, err)
23
+		}
24
+		count++
25
+		if !ranger.Contains(block) {
26
+			t.Fatalf("allocated %s which is outside of %s", block, ranger)
27
+		}
28
+		if found.Has(block.String()) {
29
+			t.Fatalf("allocated %s twice @ %d", block, count)
30
+		}
31
+		found.Insert(block.String())
32
+	}
33
+	if _, err := r.AllocateNext(); err != ErrFull {
34
+		t.Fatal(err)
35
+	}
36
+
37
+	released := uid.Block{2, 3}
38
+	if err := r.Release(released); err != nil {
39
+		t.Fatal(err)
40
+	}
41
+	if f := r.Free(); f != 1 {
42
+		t.Errorf("unexpected free %d", f)
43
+	}
44
+	block, err := r.AllocateNext()
45
+	if err != nil {
46
+		t.Fatal(err)
47
+	}
48
+	if released != block {
49
+		t.Errorf("unexpected %s : %s", block, released)
50
+	}
51
+
52
+	if err := r.Release(released); err != nil {
53
+		t.Fatal(err)
54
+	}
55
+	if err := r.Allocate(uid.Block{11, 11}); err != ErrNotInRange {
56
+		t.Fatal(err)
57
+	}
58
+	if err := r.Allocate(uid.Block{8, 11}); err != ErrNotInRange {
59
+		t.Fatal(err)
60
+	}
61
+	if f := r.Free(); f != 1 {
62
+		t.Errorf("unexpected free %d", f)
63
+	}
64
+	if err := r.Allocate(released); err != nil {
65
+		t.Fatal(err)
66
+	}
67
+	if f := r.Free(); f != 0 {
68
+		t.Errorf("unexpected free %d", f)
69
+	}
70
+}