Browse code

daemon/libnetwork/bitmap: add OnesCount method

Signed-off-by: Cory Snider <csnider@mirantis.com>

Cory Snider authored on 2025/08/14 05:14:23
Showing 3 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1 @@
0
+/testdata/rapid/**
... ...
@@ -6,6 +6,7 @@ import (
6 6
 	"encoding/json"
7 7
 	"errors"
8 8
 	"fmt"
9
+	"math/bits"
9 10
 )
10 11
 
11 12
 // block sequence constants
... ...
@@ -176,6 +177,52 @@ func (s *sequence) fromByteArray(data []byte) error {
176 176
 	return nil
177 177
 }
178 178
 
179
+// OnesCount calculates the number of selected bits in the range [start, end].
180
+func (h *Bitmap) OnesCount(start, end uint64) (uint64, error) {
181
+	if end < start || end >= h.bits {
182
+		return 0, fmt.Errorf("invalid bit range [%d, %d]", start, end)
183
+	}
184
+
185
+	// Account for the starting ordinal being partway into a block: count
186
+	// the sequence element's block once with a bitmask applied, then count
187
+	// the remaining repeats of the sequence element without the mask.
188
+	current, _, precBlocks, _ := findSequence(h.head, start/8)
189
+	var (
190
+		blocksToCount = end/blockLen - start/blockLen + 1
191
+		curblocks     = min(current.count-precBlocks, blocksToCount)
192
+		mask          = blockMAX >> (start % blockLen)
193
+		runlen        = uint64(1)
194
+		count         = uint64(0)
195
+	)
196
+	for blocksToCount > 0 {
197
+		if blocksToCount == 1 {
198
+			// We're counting the last block.
199
+			// (Which could be the same as the first block.)
200
+			// Mask off the bits beyond the end ordinal.
201
+			mask &= blockMAX << (blockLen - end%blockLen - 1)
202
+		}
203
+		count += uint64(bits.OnesCount32(current.block&mask)) * runlen
204
+		mask = blockMAX
205
+		blocksToCount -= runlen
206
+		curblocks -= runlen
207
+		if curblocks == 0 {
208
+			current = current.next
209
+			if current == nil {
210
+				break
211
+			}
212
+			curblocks = min(current.count, blocksToCount)
213
+		}
214
+		// If the block containing the end ordinal is a repeat of the
215
+		// current block, split the counting across two loop iterations.
216
+		// The final repeat of the block needs to be counted with a
217
+		// bitmask applied so we do not count bits beyond the end
218
+		// ordinal.
219
+		runlen = max(1, min(curblocks, blocksToCount-1))
220
+	}
221
+
222
+	return count, nil
223
+}
224
+
179 225
 // SetAnyInRange sets the first unset bit in the range [start, end] and returns
180 226
 // the ordinal of the set bit.
181 227
 //
... ...
@@ -5,6 +5,10 @@ import (
5 5
 	"math/rand"
6 6
 	"testing"
7 7
 	"time"
8
+
9
+	"gotest.tools/v3/assert"
10
+	is "gotest.tools/v3/assert/cmp"
11
+	"pgregory.net/rapid"
8 12
 )
9 13
 
10 14
 func TestSequenceGetAvailableBit(t *testing.T) {
... ...
@@ -1231,3 +1235,80 @@ func TestMarshalJSON(t *testing.T) {
1231 1231
 		})
1232 1232
 	}
1233 1233
 }
1234
+
1235
+func TestOnesCount(t *testing.T) {
1236
+	rapid.Check(t, func(t *rapid.T) {
1237
+		bm := New(rapid.Uint64Range(10, 1000).Draw(t, "capacity"))
1238
+		ordinals := make([]uint64, bm.Bits())
1239
+		for i := range ordinals {
1240
+			ordinals[i] = uint64(i)
1241
+		}
1242
+		nBitsToSet := rapid.IntRange(0, int(bm.Bits()-1)).Draw(t, "nBitsToSet")
1243
+		selectedOrdinals := rapid.Permutation(ordinals).Draw(t, "selectedOrdinals")[:nBitsToSet]
1244
+		for _, i := range selectedOrdinals {
1245
+			assert.NilError(t, bm.Set(i))
1246
+		}
1247
+		t.Logf("%v", bm)
1248
+
1249
+		got, err := bm.OnesCount(0, bm.Bits()-1)
1250
+		assert.Check(t, err, "OnesCount of all ordinals should succeed")
1251
+		assert.Check(t, is.Equal(got, bm.Bits()-bm.Unselected()), "OnesCount of all ordinals should equal Bits-Unselected")
1252
+
1253
+		idxgen := rapid.Uint64Range(0, bm.Bits()-1)
1254
+		for range 1000 {
1255
+			start, end := idxgen.Draw(t, "start"), idxgen.Draw(t, "end")
1256
+			if start > end {
1257
+				start, end = end, start
1258
+			}
1259
+			var expected uint64
1260
+			for _, i := range selectedOrdinals {
1261
+				if i >= start && i <= end {
1262
+					expected++
1263
+				}
1264
+			}
1265
+			got, err := bm.OnesCount(start, end)
1266
+			assert.NilError(t, err)
1267
+			assert.Check(t, is.Equal(got, expected))
1268
+		}
1269
+	})
1270
+
1271
+	bm := New(8*blockLen + 1)
1272
+	for _, r := range []struct{ from, to uint64 }{
1273
+		// 0b0111111... repeated 3x
1274
+		{1, blockLen - 1},
1275
+		{blockLen + 1, 2*blockLen - 1},
1276
+		{2*blockLen + 1, 3*blockLen - 1},
1277
+
1278
+		// 0b0001111...
1279
+		{3*blockLen + 3, 3*blockLen + 6},
1280
+
1281
+		// 0b111111....., 0b1000000....
1282
+		{5 * blockLen, 6 * blockLen},
1283
+
1284
+		{8 * blockLen, 8 * blockLen},
1285
+	} {
1286
+		for i := r.from; i <= r.to; i++ {
1287
+			assert.NilError(t, bm.Set(i))
1288
+		}
1289
+	}
1290
+	t.Logf("%v", bm)
1291
+
1292
+	expectOnesCount := func(from, to uint64, expected uint64) {
1293
+		t.Helper()
1294
+		got, err := bm.OnesCount(from, to)
1295
+		assert.NilError(t, err)
1296
+		assert.Check(t, is.Equal(got, expected))
1297
+	}
1298
+
1299
+	expectOnesCount(0, bm.Bits()-1, bm.Bits()-bm.Unselected())
1300
+	expectOnesCount(0, 0, 0)
1301
+	expectOnesCount(0, 1, 1)
1302
+	expectOnesCount(1, 1, 1)
1303
+	expectOnesCount(3, 12, 12-3+1)
1304
+	expectOnesCount(blockLen-1, 2*blockLen+1, blockLen+1)
1305
+	expectOnesCount(5, 3*blockLen-3, blockLen-1-4+blockLen-1+blockLen-1-2)
1306
+	expectOnesCount(3*blockLen, 3*blockLen+4, 2)
1307
+	expectOnesCount(3*blockLen, 5*blockLen-1, 4)
1308
+	expectOnesCount(6*blockLen, 8*blockLen-1, 1)
1309
+	expectOnesCount(8*blockLen, 8*blockLen, 1)
1310
+}