Skip to content

Commit

Permalink
Fix shorthand combination edge case in c.Traverse() code path
Browse files Browse the repository at this point in the history
  • Loading branch information
inicula committed Sep 14, 2024
1 parent 8ba575e commit b07c5cd
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
29 changes: 29 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,35 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) {
// A flag without a value, or with an `=` separated value
case isFlagArg(arg):
flags = append(flags, arg)

if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
continue // Not a shorthand combination, so nothing more to do.
}

shorthandCombination := arg[1:] // Skip leading "-"
lastPos := len(shorthandCombination) - 1
for i, shorthand := range shorthandCombination {
if shortHasNoOptDefVal(string(shorthand), c.Flags()) {
continue
}

// We found a shorthand that needs a value.

if i == lastPos {
// Since we're at the end of the shorthand combination, this means that the
// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
inFlag = true
} else {
// Since the shorthand combination doesn't end here, this means that the
// value for the shorthand is given in the same argument, meaning we don't
// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
// boolean flags, and -f is a flag that needs a value).
}

break
}

continue
}

Expand Down
60 changes: 58 additions & 2 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2253,14 +2253,70 @@ func TestTraverseWithParentFlags(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--add" {
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
}

func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")

childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")

rootCmd.AddCommand(childCmd)

c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "ok" {
t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}

func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")

childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")

rootCmd.AddCommand(childCmd)

c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "child" {
t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}

func TestTraverseNoParentFlags(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
rootCmd.Flags().String("foo", "", "foo things")
Expand Down Expand Up @@ -2312,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--str" {
if len(args) != 1 || args[0] != "--str" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
Expand Down

0 comments on commit b07c5cd

Please sign in to comment.