Skip to content

Commit

Permalink
Add func args support.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbenguigui authored Jun 20, 2022
1 parent f7d6bec commit d039d19
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mocktail.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ func getTypeImports(t types.Type) []string {
case *types.Interface:
return []string{""}

case *types.Signature:
return getTupleImports(v.Params(), v.Results())

default:
panic(fmt.Sprintf("OOPS %[1]T %[1]s", t))
}
Expand Down
34 changes: 32 additions & 2 deletions syrup.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,12 @@ func (s Syrup) methodOn(writer io.Writer) error {
name := getParamName(param, i)

w.Print(name)
argNames = append(argNames, name)

if _, ok := param.Type().(*types.Signature); ok {
argNames = append(argNames, "mock.Anything")
} else {
argNames = append(argNames, name)
}

w.Print(" " + s.getTypeName(param.Type(), i == params.Len()-1))

Expand Down Expand Up @@ -312,7 +317,12 @@ func (s Syrup) methodOnRaw(writer io.Writer) error {
name := getParamName(param, i)

w.Print(name)
argNames = append(argNames, name)

if _, ok := param.Type().(*types.Signature); ok {
argNames = append(argNames, "mock.Anything")
} else {
argNames = append(argNames, name)
}

w.Print(" interface{}")

Expand Down Expand Up @@ -616,11 +626,31 @@ func (s Syrup) getTypeName(t types.Type, last bool) string {
case *types.Interface:
return v.String()

case *types.Signature:
fn := "func(" + strings.Join(s.getTupleTypes(v.Params()), ",") + ")"

if v.Results().Len() > 0 {
fn += " (" + strings.Join(s.getTupleTypes(v.Results()), ",") + ")"
}

return fn

default:
panic(fmt.Sprintf("OOPS %[1]T %[1]s", t))
}
}

func (s Syrup) getTupleTypes(t *types.Tuple) []string {
var tupleTypes []string
for i := 0; i < t.Len(); i++ {
param := t.At(i)

tupleTypes = append(tupleTypes, s.getTypeName(param.Type(), false))
}

return tupleTypes
}

func writeImports(writer io.Writer, descPkg PackageDesc) error {
base := template.New("templateImports")

Expand Down
1 change: 1 addition & 0 deletions testdata/src/a/foo.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Coconut interface {
Voo(src *module.Version) time.Duration
Yoo(st string) interface{}
Zoo(st interface{}) string
Moo(fn func(st, stban Strawberry) Pineapple) string
}

type Water struct{}
Expand Down
Loading

0 comments on commit d039d19

Please sign in to comment.