Browse code

generate plugin clients via template

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2015/06/09 00:33:06
Showing 5 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,35 @@
0
+package foo
1
+
2
+type wobble struct {
3
+	Some      string
4
+	Val       string
5
+	Inception *wobble
6
+}
7
+
8
+type Fooer interface{}
9
+
10
+type Fooer2 interface {
11
+	Foo()
12
+}
13
+
14
+type Fooer3 interface {
15
+	Foo()
16
+	Bar(a string)
17
+	Baz(a string) (err error)
18
+	Qux(a, b string) (val string, err error)
19
+	Wobble() (w *wobble)
20
+	Wiggle() (w wobble)
21
+}
22
+
23
+type Fooer4 interface {
24
+	Foo() error
25
+}
26
+
27
+type Bar interface {
28
+	Boo(a string, b string) (s string, err error)
29
+}
30
+
31
+type Fooer5 interface {
32
+	Foo()
33
+	Bar
34
+}
0 35
new file mode 100644
... ...
@@ -0,0 +1,91 @@
0
+package main
1
+
2
+import (
3
+	"bytes"
4
+	"flag"
5
+	"fmt"
6
+	"go/format"
7
+	"io/ioutil"
8
+	"os"
9
+	"unicode"
10
+	"unicode/utf8"
11
+)
12
+
13
+type stringSet struct {
14
+	values map[string]struct{}
15
+}
16
+
17
+func (s stringSet) String() string {
18
+	return ""
19
+}
20
+
21
+func (s stringSet) Set(value string) error {
22
+	s.values[value] = struct{}{}
23
+	return nil
24
+}
25
+func (s stringSet) GetValues() map[string]struct{} {
26
+	return s.values
27
+}
28
+
29
+var (
30
+	typeName   = flag.String("type", "", "interface type to generate plugin rpc proxy for")
31
+	rpcName    = flag.String("name", *typeName, "RPC name, set if different from type")
32
+	inputFile  = flag.String("i", "", "input file path")
33
+	outputFile = flag.String("o", *inputFile+"_proxy.go", "output file path")
34
+
35
+	skipFuncs   map[string]struct{}
36
+	flSkipFuncs = stringSet{make(map[string]struct{})}
37
+
38
+	flBuildTags = stringSet{make(map[string]struct{})}
39
+)
40
+
41
+func errorOut(msg string, err error) {
42
+	if err == nil {
43
+		return
44
+	}
45
+	fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err)
46
+	os.Exit(1)
47
+}
48
+
49
+func checkFlags() error {
50
+	if *outputFile == "" {
51
+		return fmt.Errorf("missing required flag `-o`")
52
+	}
53
+	if *inputFile == "" {
54
+		return fmt.Errorf("missing required flag `-i`")
55
+	}
56
+	return nil
57
+}
58
+
59
+func main() {
60
+	flag.Var(flSkipFuncs, "skip", "skip parsing for function")
61
+	flag.Var(flBuildTags, "tag", "build tags to add to generated files")
62
+	flag.Parse()
63
+	skipFuncs = flSkipFuncs.GetValues()
64
+
65
+	errorOut("error", checkFlags())
66
+
67
+	pkg, err := Parse(*inputFile, *typeName)
68
+	errorOut(fmt.Sprintf("error parsing requested type %s", *typeName), err)
69
+
70
+	var analysis = struct {
71
+		InterfaceType string
72
+		RPCName       string
73
+		BuildTags     map[string]struct{}
74
+		*parsedPkg
75
+	}{toLower(*typeName), *rpcName, flBuildTags.GetValues(), pkg}
76
+	var buf bytes.Buffer
77
+
78
+	errorOut("parser error", generatedTempl.Execute(&buf, analysis))
79
+	src, err := format.Source(buf.Bytes())
80
+	errorOut("error formating generated source", err)
81
+	errorOut("error writing file", ioutil.WriteFile(*outputFile, src, 0644))
82
+}
83
+
84
+func toLower(s string) string {
85
+	if s == "" {
86
+		return ""
87
+	}
88
+	r, n := utf8.DecodeRuneInString(s)
89
+	return string(unicode.ToLower(r)) + s[n:]
90
+}
0 91
new file mode 100644
... ...
@@ -0,0 +1,162 @@
0
+package main
1
+
2
+import (
3
+	"errors"
4
+	"fmt"
5
+	"go/ast"
6
+	"go/parser"
7
+	"go/token"
8
+	"reflect"
9
+	"strings"
10
+)
11
+
12
+var ErrBadReturn = errors.New("found return arg with no name: all args must be named")
13
+
14
+type ErrUnexpectedType struct {
15
+	expected string
16
+	actual   interface{}
17
+}
18
+
19
+func (e ErrUnexpectedType) Error() string {
20
+	return fmt.Sprintf("got wrong type expecting %s, got: %v", e.expected, reflect.TypeOf(e.actual))
21
+}
22
+
23
+type parsedPkg struct {
24
+	Name      string
25
+	Functions []function
26
+}
27
+
28
+type function struct {
29
+	Name    string
30
+	Args    []arg
31
+	Returns []arg
32
+	Doc     string
33
+}
34
+
35
+type arg struct {
36
+	Name    string
37
+	ArgType string
38
+}
39
+
40
+func (a *arg) String() string {
41
+	return strings.ToLower(a.Name) + " " + strings.ToLower(a.ArgType)
42
+}
43
+
44
+// Parses the given file for an interface definition with the given name
45
+func Parse(filePath string, objName string) (*parsedPkg, error) {
46
+	fs := token.NewFileSet()
47
+	pkg, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors)
48
+	if err != nil {
49
+		return nil, err
50
+	}
51
+	p := &parsedPkg{}
52
+	p.Name = pkg.Name.Name
53
+	obj, exists := pkg.Scope.Objects[objName]
54
+	if !exists {
55
+		return nil, fmt.Errorf("could not find object %s in %s", objName, filePath)
56
+	}
57
+	if obj.Kind != ast.Typ {
58
+		return nil, fmt.Errorf("exected type, got %s", obj.Kind)
59
+	}
60
+	spec, ok := obj.Decl.(*ast.TypeSpec)
61
+	if !ok {
62
+		return nil, ErrUnexpectedType{"*ast.TypeSpec", obj.Decl}
63
+	}
64
+	iface, ok := spec.Type.(*ast.InterfaceType)
65
+	if !ok {
66
+		return nil, ErrUnexpectedType{"*ast.InterfaceType", spec.Type}
67
+	}
68
+
69
+	p.Functions, err = parseInterface(iface)
70
+	if err != nil {
71
+		return nil, err
72
+	}
73
+
74
+	return p, nil
75
+}
76
+
77
+func parseInterface(iface *ast.InterfaceType) ([]function, error) {
78
+	var functions []function
79
+	for _, field := range iface.Methods.List {
80
+		switch f := field.Type.(type) {
81
+		case *ast.FuncType:
82
+			method, err := parseFunc(field)
83
+			if err != nil {
84
+				return nil, err
85
+			}
86
+			if method == nil {
87
+				continue
88
+			}
89
+			functions = append(functions, *method)
90
+		case *ast.Ident:
91
+			spec, ok := f.Obj.Decl.(*ast.TypeSpec)
92
+			if !ok {
93
+				return nil, ErrUnexpectedType{"*ast.TypeSpec", f.Obj.Decl}
94
+			}
95
+			iface, ok := spec.Type.(*ast.InterfaceType)
96
+			if !ok {
97
+				return nil, ErrUnexpectedType{"*ast.TypeSpec", spec.Type}
98
+			}
99
+			funcs, err := parseInterface(iface)
100
+			if err != nil {
101
+				fmt.Println(err)
102
+				continue
103
+			}
104
+			functions = append(functions, funcs...)
105
+		default:
106
+			return nil, ErrUnexpectedType{"*astFuncType or *ast.Ident", f}
107
+		}
108
+	}
109
+	return functions, nil
110
+}
111
+
112
+func parseFunc(field *ast.Field) (*function, error) {
113
+	f := field.Type.(*ast.FuncType)
114
+	method := &function{Name: field.Names[0].Name}
115
+	if _, exists := skipFuncs[method.Name]; exists {
116
+		fmt.Println("skipping:", method.Name)
117
+		return nil, nil
118
+	}
119
+	if f.Params != nil {
120
+		args, err := parseArgs(f.Params.List)
121
+		if err != nil {
122
+			return nil, err
123
+		}
124
+		method.Args = args
125
+	}
126
+	if f.Results != nil {
127
+		returns, err := parseArgs(f.Results.List)
128
+		if err != nil {
129
+			return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err)
130
+		}
131
+		method.Returns = returns
132
+	}
133
+	return method, nil
134
+}
135
+
136
+func parseArgs(fields []*ast.Field) ([]arg, error) {
137
+	var args []arg
138
+	for _, f := range fields {
139
+		if len(f.Names) == 0 {
140
+			return nil, ErrBadReturn
141
+		}
142
+		for _, name := range f.Names {
143
+			var typeName string
144
+			switch argType := f.Type.(type) {
145
+			case *ast.Ident:
146
+				typeName = argType.Name
147
+			case *ast.StarExpr:
148
+				i, ok := argType.X.(*ast.Ident)
149
+				if !ok {
150
+					return nil, ErrUnexpectedType{"*ast.Ident", f.Type}
151
+				}
152
+				typeName = "*" + i.Name
153
+			default:
154
+				return nil, ErrUnexpectedType{"*ast.Ident or *ast.StarExpr", f.Type}
155
+			}
156
+
157
+			args = append(args, arg{name.Name, typeName})
158
+		}
159
+	}
160
+	return args, nil
161
+}
0 162
new file mode 100644
... ...
@@ -0,0 +1,168 @@
0
+package main
1
+
2
+import (
3
+	"fmt"
4
+	"path/filepath"
5
+	"runtime"
6
+	"strings"
7
+	"testing"
8
+)
9
+
10
+const testFixture = "fixtures/foo.go"
11
+
12
+func TestParseEmptyInterface(t *testing.T) {
13
+	pkg, err := Parse(testFixture, "Fooer")
14
+	if err != nil {
15
+		t.Fatal(err)
16
+	}
17
+
18
+	assertName(t, "foo", pkg.Name)
19
+	assertNum(t, 0, len(pkg.Functions))
20
+}
21
+
22
+func TestParseNonInterfaceType(t *testing.T) {
23
+	_, err := Parse(testFixture, "wobble")
24
+	if _, ok := err.(ErrUnexpectedType); !ok {
25
+		t.Fatal("expected type error when parsing non-interface type")
26
+	}
27
+}
28
+
29
+func TestParseWithOneFunction(t *testing.T) {
30
+	pkg, err := Parse(testFixture, "Fooer2")
31
+	if err != nil {
32
+		t.Fatal(err)
33
+	}
34
+
35
+	assertName(t, "foo", pkg.Name)
36
+	assertNum(t, 1, len(pkg.Functions))
37
+	assertName(t, "Foo", pkg.Functions[0].Name)
38
+	assertNum(t, 0, len(pkg.Functions[0].Args))
39
+	assertNum(t, 0, len(pkg.Functions[0].Returns))
40
+}
41
+
42
+func TestParseWithMultipleFuncs(t *testing.T) {
43
+	pkg, err := Parse(testFixture, "Fooer3")
44
+	if err != nil {
45
+		t.Fatal(err)
46
+	}
47
+
48
+	assertName(t, "foo", pkg.Name)
49
+	assertNum(t, 6, len(pkg.Functions))
50
+
51
+	f := pkg.Functions[0]
52
+	assertName(t, "Foo", f.Name)
53
+	assertNum(t, 0, len(f.Args))
54
+	assertNum(t, 0, len(f.Returns))
55
+
56
+	f = pkg.Functions[1]
57
+	assertName(t, "Bar", f.Name)
58
+	assertNum(t, 1, len(f.Args))
59
+	assertNum(t, 0, len(f.Returns))
60
+	arg := f.Args[0]
61
+	assertName(t, "a", arg.Name)
62
+	assertName(t, "string", arg.ArgType)
63
+
64
+	f = pkg.Functions[2]
65
+	assertName(t, "Baz", f.Name)
66
+	assertNum(t, 1, len(f.Args))
67
+	assertNum(t, 1, len(f.Returns))
68
+	arg = f.Args[0]
69
+	assertName(t, "a", arg.Name)
70
+	assertName(t, "string", arg.ArgType)
71
+	arg = f.Returns[0]
72
+	assertName(t, "err", arg.Name)
73
+	assertName(t, "error", arg.ArgType)
74
+
75
+	f = pkg.Functions[3]
76
+	assertName(t, "Qux", f.Name)
77
+	assertNum(t, 2, len(f.Args))
78
+	assertNum(t, 2, len(f.Returns))
79
+	arg = f.Args[0]
80
+	assertName(t, "a", f.Args[0].Name)
81
+	assertName(t, "string", f.Args[0].ArgType)
82
+	arg = f.Args[1]
83
+	assertName(t, "b", arg.Name)
84
+	assertName(t, "string", arg.ArgType)
85
+	arg = f.Returns[0]
86
+	assertName(t, "val", arg.Name)
87
+	assertName(t, "string", arg.ArgType)
88
+	arg = f.Returns[1]
89
+	assertName(t, "err", arg.Name)
90
+	assertName(t, "error", arg.ArgType)
91
+
92
+	f = pkg.Functions[4]
93
+	assertName(t, "Wobble", f.Name)
94
+	assertNum(t, 0, len(f.Args))
95
+	assertNum(t, 1, len(f.Returns))
96
+	arg = f.Returns[0]
97
+	assertName(t, "w", arg.Name)
98
+	assertName(t, "*wobble", arg.ArgType)
99
+
100
+	f = pkg.Functions[5]
101
+	assertName(t, "Wiggle", f.Name)
102
+	assertNum(t, 0, len(f.Args))
103
+	assertNum(t, 1, len(f.Returns))
104
+	arg = f.Returns[0]
105
+	assertName(t, "w", arg.Name)
106
+	assertName(t, "wobble", arg.ArgType)
107
+}
108
+
109
+func TestParseWithUnamedReturn(t *testing.T) {
110
+	_, err := Parse(testFixture, "Fooer4")
111
+	if !strings.HasSuffix(err.Error(), ErrBadReturn.Error()) {
112
+		t.Fatalf("expected ErrBadReturn, got %v", err)
113
+	}
114
+}
115
+
116
+func TestEmbeddedInterface(t *testing.T) {
117
+	pkg, err := Parse(testFixture, "Fooer5")
118
+	if err != nil {
119
+		t.Fatal(err)
120
+	}
121
+
122
+	assertName(t, "foo", pkg.Name)
123
+	assertNum(t, 2, len(pkg.Functions))
124
+
125
+	f := pkg.Functions[0]
126
+	assertName(t, "Foo", f.Name)
127
+	assertNum(t, 0, len(f.Args))
128
+	assertNum(t, 0, len(f.Returns))
129
+
130
+	f = pkg.Functions[1]
131
+	assertName(t, "Boo", f.Name)
132
+	assertNum(t, 2, len(f.Args))
133
+	assertNum(t, 2, len(f.Returns))
134
+
135
+	arg := f.Args[0]
136
+	assertName(t, "a", arg.Name)
137
+	assertName(t, "string", arg.ArgType)
138
+
139
+	arg = f.Args[1]
140
+	assertName(t, "b", arg.Name)
141
+	assertName(t, "string", arg.ArgType)
142
+
143
+	arg = f.Returns[0]
144
+	assertName(t, "s", arg.Name)
145
+	assertName(t, "string", arg.ArgType)
146
+
147
+	arg = f.Returns[1]
148
+	assertName(t, "err", arg.Name)
149
+	assertName(t, "error", arg.ArgType)
150
+}
151
+
152
+func assertName(t *testing.T, expected, actual string) {
153
+	if expected != actual {
154
+		fatalOut(t, fmt.Sprintf("expected name to be `%s`, got: %s", expected, actual))
155
+	}
156
+}
157
+
158
+func assertNum(t *testing.T, expected, actual int) {
159
+	if expected != actual {
160
+		fatalOut(t, fmt.Sprintf("expected number to be %d, got: %d", expected, actual))
161
+	}
162
+}
163
+
164
+func fatalOut(t *testing.T, msg string) {
165
+	_, file, ln, _ := runtime.Caller(2)
166
+	t.Fatalf("%s:%d: %s", filepath.Base(file), ln, msg)
167
+}
0 168
new file mode 100644
... ...
@@ -0,0 +1,97 @@
0
+package main
1
+
2
+import (
3
+	"strings"
4
+	"text/template"
5
+)
6
+
7
+func printArgs(args []arg) string {
8
+	var argStr []string
9
+	for _, arg := range args {
10
+		argStr = append(argStr, arg.String())
11
+	}
12
+	return strings.Join(argStr, ", ")
13
+}
14
+
15
+func marshalType(t string) string {
16
+	switch t {
17
+	case "error":
18
+		// convert error types to plain strings to ensure the values are encoded/decoded properly
19
+		return "string"
20
+	default:
21
+		return t
22
+	}
23
+}
24
+
25
+func isErr(t string) bool {
26
+	switch t {
27
+	case "error":
28
+		return true
29
+	default:
30
+		return false
31
+	}
32
+}
33
+
34
+// Need to use this helper due to issues with go-vet
35
+func buildTag(s string) string {
36
+	return "+build " + s
37
+}
38
+
39
+var templFuncs = template.FuncMap{
40
+	"printArgs":   printArgs,
41
+	"marshalType": marshalType,
42
+	"isErr":       isErr,
43
+	"lower":       strings.ToLower,
44
+	"title":       strings.Title,
45
+	"tag":         buildTag,
46
+}
47
+
48
+var generatedTempl = template.Must(template.New("rpc_cient").Funcs(templFuncs).Parse(`
49
+// generated code - DO NOT EDIT
50
+{{ range $k, $v := .BuildTags }}
51
+	// {{ tag $k }} {{ end }}
52
+
53
+package {{ .Name }}
54
+
55
+import "errors"
56
+
57
+type client interface{
58
+	Call(string, interface{}, interface{}) error
59
+}
60
+
61
+type {{ .InterfaceType }}Proxy struct {
62
+	client
63
+}
64
+
65
+{{ range .Functions }}
66
+	type {{ $.InterfaceType }}Proxy{{ .Name }}Request struct{
67
+		{{ range .Args }}
68
+			{{ title .Name }} {{ .ArgType }} {{ end }}
69
+	}
70
+
71
+	type {{ $.InterfaceType }}Proxy{{ .Name }}Response struct{
72
+		{{ range .Returns }}
73
+			{{ title .Name }} {{ marshalType .ArgType }} {{ end }}
74
+	}
75
+
76
+	func (pp *{{ $.InterfaceType }}Proxy) {{ .Name }}({{ printArgs .Args }}) ({{ printArgs .Returns }}) {
77
+		var(
78
+			req {{ $.InterfaceType }}Proxy{{ .Name }}Request
79
+			ret {{ $.InterfaceType }}Proxy{{ .Name }}Response
80
+		)
81
+		{{ range .Args }}
82
+			req.{{ title .Name }} = {{ lower .Name }} {{ end }}
83
+		if err = pp.Call("{{ $.RPCName }}.{{ .Name }}", req, &ret); err != nil {
84
+			return
85
+		}
86
+		{{ range $r := .Returns }}
87
+			{{ if isErr .ArgType }}
88
+				if ret.{{ title .Name }} != "" {
89
+					{{ lower .Name }} = errors.New(ret.{{ title .Name }})
90
+				} {{ end }}
91
+			{{ if isErr .ArgType | not }} {{ lower .Name }} = ret.{{ title .Name }} {{ end }} {{ end }}
92
+
93
+		return
94
+	}
95
+{{ end }}
96
+`))