Browse code

fs.Visit() returns nil flag

Signed-off-by: Sven Dowideit <SvenDowideit@docker.com>

Sven Dowideit authored on 2014/08/29 13:55:49
Showing 2 changed files
... ...
@@ -317,8 +317,13 @@ func (p flagSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
317 317
 // sortFlags returns the flags as a slice in lexicographical sorted order.
318 318
 func sortFlags(flags map[string]*Flag) []*Flag {
319 319
 	var list flagSlice
320
-	for _, f := range flags {
320
+
321
+	// The sorted list is based on the first name, when flag map might use the other names.
322
+	nameMap := make(map[string]string)
323
+
324
+	for n, f := range flags {
321 325
 		fName := strings.TrimPrefix(f.Names[0], "#")
326
+		nameMap[fName] = n
322 327
 		if len(f.Names) == 1 {
323 328
 			list = append(list, fName)
324 329
 			continue
... ...
@@ -338,7 +343,7 @@ func sortFlags(flags map[string]*Flag) []*Flag {
338 338
 	sort.Sort(list)
339 339
 	result := make([]*Flag, len(list))
340 340
 	for i, name := range list {
341
-		result[i] = flags[name]
341
+		result[i] = flags[nameMap[name]]
342 342
 	}
343 343
 	return result
344 344
 }
... ...
@@ -473,7 +478,7 @@ var Usage = func() {
473 473
 }
474 474
 
475 475
 // FlagCount returns the number of flags that have been defined.
476
-func (f *FlagSet) FlagCount() int { return len(f.formal) }
476
+func (f *FlagSet) FlagCount() int { return len(sortFlags(f.formal)) }
477 477
 
478 478
 // FlagCountUndeprecated returns the number of undeprecated flags that have been defined.
479 479
 func (f *FlagSet) FlagCountUndeprecated() int {
... ...
@@ -440,7 +440,7 @@ func TestFlagCounts(t *testing.T) {
440 440
 	fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag")
441 441
 	fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag")
442 442
 
443
-	if fs.FlagCount() != 10 {
443
+	if fs.FlagCount() != 6 {
444 444
 		t.Fatal("FlagCount wrong. ", fs.FlagCount())
445 445
 	}
446 446
 	if fs.FlagCountUndeprecated() != 4 {
... ...
@@ -457,3 +457,50 @@ func TestFlagCounts(t *testing.T) {
457 457
 		t.Fatal("NFlag wrong. ", fs.NFlag())
458 458
 	}
459 459
 }
460
+
461
+// Show up bug in sortFlags
462
+func TestSortFlags(t *testing.T) {
463
+	fs := NewFlagSet("help TestSortFlags", ContinueOnError)
464
+
465
+	var err error
466
+
467
+	var b bool
468
+	fs.BoolVar(&b, []string{"b", "-banana"}, false, "usage")
469
+
470
+	err = fs.Parse([]string{"--banana=true"})
471
+	if err != nil {
472
+		t.Fatal("expected no error; got ", err)
473
+	}
474
+
475
+	count := 0
476
+
477
+	fs.VisitAll(func(flag *Flag) {
478
+		count++
479
+		if flag == nil {
480
+			t.Fatal("VisitAll should not return a nil flag")
481
+		}
482
+	})
483
+	flagcount := fs.FlagCount()
484
+	if flagcount != count {
485
+		t.Fatalf("FlagCount (%d) != number (%d) of elements visited", flagcount, count)
486
+	}
487
+	// Make sure its idempotent
488
+	if flagcount != fs.FlagCount() {
489
+		t.Fatalf("FlagCount (%d) != fs.FlagCount() (%d) of elements visited", flagcount, fs.FlagCount())
490
+	}
491
+
492
+	count = 0
493
+	fs.Visit(func(flag *Flag) {
494
+		count++
495
+		if flag == nil {
496
+			t.Fatal("Visit should not return a nil flag")
497
+		}
498
+	})
499
+	nflag := fs.NFlag()
500
+	if nflag != count {
501
+		t.Fatalf("NFlag (%d) != number (%d) of elements visited", nflag, count)
502
+	}
503
+	if nflag != fs.NFlag() {
504
+		t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag())
505
+	}
506
+}