Browse Source

generate plugin clients via template

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Brian Goff 10 years ago
parent
commit
4c81c9dddc

+ 35 - 0
pkg/plugins/pluginrpc-gen/fixtures/foo.go

@@ -0,0 +1,35 @@
+package foo
+
+type wobble struct {
+	Some      string
+	Val       string
+	Inception *wobble
+}
+
+type Fooer interface{}
+
+type Fooer2 interface {
+	Foo()
+}
+
+type Fooer3 interface {
+	Foo()
+	Bar(a string)
+	Baz(a string) (err error)
+	Qux(a, b string) (val string, err error)
+	Wobble() (w *wobble)
+	Wiggle() (w wobble)
+}
+
+type Fooer4 interface {
+	Foo() error
+}
+
+type Bar interface {
+	Boo(a string, b string) (s string, err error)
+}
+
+type Fooer5 interface {
+	Foo()
+	Bar
+}

+ 91 - 0
pkg/plugins/pluginrpc-gen/main.go

@@ -0,0 +1,91 @@
+package main
+
+import (
+	"bytes"
+	"flag"
+	"fmt"
+	"go/format"
+	"io/ioutil"
+	"os"
+	"unicode"
+	"unicode/utf8"
+)
+
+type stringSet struct {
+	values map[string]struct{}
+}
+
+func (s stringSet) String() string {
+	return ""
+}
+
+func (s stringSet) Set(value string) error {
+	s.values[value] = struct{}{}
+	return nil
+}
+func (s stringSet) GetValues() map[string]struct{} {
+	return s.values
+}
+
+var (
+	typeName   = flag.String("type", "", "interface type to generate plugin rpc proxy for")
+	rpcName    = flag.String("name", *typeName, "RPC name, set if different from type")
+	inputFile  = flag.String("i", "", "input file path")
+	outputFile = flag.String("o", *inputFile+"_proxy.go", "output file path")
+
+	skipFuncs   map[string]struct{}
+	flSkipFuncs = stringSet{make(map[string]struct{})}
+
+	flBuildTags = stringSet{make(map[string]struct{})}
+)
+
+func errorOut(msg string, err error) {
+	if err == nil {
+		return
+	}
+	fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err)
+	os.Exit(1)
+}
+
+func checkFlags() error {
+	if *outputFile == "" {
+		return fmt.Errorf("missing required flag `-o`")
+	}
+	if *inputFile == "" {
+		return fmt.Errorf("missing required flag `-i`")
+	}
+	return nil
+}
+
+func main() {
+	flag.Var(flSkipFuncs, "skip", "skip parsing for function")
+	flag.Var(flBuildTags, "tag", "build tags to add to generated files")
+	flag.Parse()
+	skipFuncs = flSkipFuncs.GetValues()
+
+	errorOut("error", checkFlags())
+
+	pkg, err := Parse(*inputFile, *typeName)
+	errorOut(fmt.Sprintf("error parsing requested type %s", *typeName), err)
+
+	var analysis = struct {
+		InterfaceType string
+		RPCName       string
+		BuildTags     map[string]struct{}
+		*parsedPkg
+	}{toLower(*typeName), *rpcName, flBuildTags.GetValues(), pkg}
+	var buf bytes.Buffer
+
+	errorOut("parser error", generatedTempl.Execute(&buf, analysis))
+	src, err := format.Source(buf.Bytes())
+	errorOut("error formating generated source", err)
+	errorOut("error writing file", ioutil.WriteFile(*outputFile, src, 0644))
+}
+
+func toLower(s string) string {
+	if s == "" {
+		return ""
+	}
+	r, n := utf8.DecodeRuneInString(s)
+	return string(unicode.ToLower(r)) + s[n:]
+}

+ 162 - 0
pkg/plugins/pluginrpc-gen/parser.go

@@ -0,0 +1,162 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/token"
+	"reflect"
+	"strings"
+)
+
+var ErrBadReturn = errors.New("found return arg with no name: all args must be named")
+
+type ErrUnexpectedType struct {
+	expected string
+	actual   interface{}
+}
+
+func (e ErrUnexpectedType) Error() string {
+	return fmt.Sprintf("got wrong type expecting %s, got: %v", e.expected, reflect.TypeOf(e.actual))
+}
+
+type parsedPkg struct {
+	Name      string
+	Functions []function
+}
+
+type function struct {
+	Name    string
+	Args    []arg
+	Returns []arg
+	Doc     string
+}
+
+type arg struct {
+	Name    string
+	ArgType string
+}
+
+func (a *arg) String() string {
+	return strings.ToLower(a.Name) + " " + strings.ToLower(a.ArgType)
+}
+
+// Parses the given file for an interface definition with the given name
+func Parse(filePath string, objName string) (*parsedPkg, error) {
+	fs := token.NewFileSet()
+	pkg, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors)
+	if err != nil {
+		return nil, err
+	}
+	p := &parsedPkg{}
+	p.Name = pkg.Name.Name
+	obj, exists := pkg.Scope.Objects[objName]
+	if !exists {
+		return nil, fmt.Errorf("could not find object %s in %s", objName, filePath)
+	}
+	if obj.Kind != ast.Typ {
+		return nil, fmt.Errorf("exected type, got %s", obj.Kind)
+	}
+	spec, ok := obj.Decl.(*ast.TypeSpec)
+	if !ok {
+		return nil, ErrUnexpectedType{"*ast.TypeSpec", obj.Decl}
+	}
+	iface, ok := spec.Type.(*ast.InterfaceType)
+	if !ok {
+		return nil, ErrUnexpectedType{"*ast.InterfaceType", spec.Type}
+	}
+
+	p.Functions, err = parseInterface(iface)
+	if err != nil {
+		return nil, err
+	}
+
+	return p, nil
+}
+
+func parseInterface(iface *ast.InterfaceType) ([]function, error) {
+	var functions []function
+	for _, field := range iface.Methods.List {
+		switch f := field.Type.(type) {
+		case *ast.FuncType:
+			method, err := parseFunc(field)
+			if err != nil {
+				return nil, err
+			}
+			if method == nil {
+				continue
+			}
+			functions = append(functions, *method)
+		case *ast.Ident:
+			spec, ok := f.Obj.Decl.(*ast.TypeSpec)
+			if !ok {
+				return nil, ErrUnexpectedType{"*ast.TypeSpec", f.Obj.Decl}
+			}
+			iface, ok := spec.Type.(*ast.InterfaceType)
+			if !ok {
+				return nil, ErrUnexpectedType{"*ast.TypeSpec", spec.Type}
+			}
+			funcs, err := parseInterface(iface)
+			if err != nil {
+				fmt.Println(err)
+				continue
+			}
+			functions = append(functions, funcs...)
+		default:
+			return nil, ErrUnexpectedType{"*astFuncType or *ast.Ident", f}
+		}
+	}
+	return functions, nil
+}
+
+func parseFunc(field *ast.Field) (*function, error) {
+	f := field.Type.(*ast.FuncType)
+	method := &function{Name: field.Names[0].Name}
+	if _, exists := skipFuncs[method.Name]; exists {
+		fmt.Println("skipping:", method.Name)
+		return nil, nil
+	}
+	if f.Params != nil {
+		args, err := parseArgs(f.Params.List)
+		if err != nil {
+			return nil, err
+		}
+		method.Args = args
+	}
+	if f.Results != nil {
+		returns, err := parseArgs(f.Results.List)
+		if err != nil {
+			return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err)
+		}
+		method.Returns = returns
+	}
+	return method, nil
+}
+
+func parseArgs(fields []*ast.Field) ([]arg, error) {
+	var args []arg
+	for _, f := range fields {
+		if len(f.Names) == 0 {
+			return nil, ErrBadReturn
+		}
+		for _, name := range f.Names {
+			var typeName string
+			switch argType := f.Type.(type) {
+			case *ast.Ident:
+				typeName = argType.Name
+			case *ast.StarExpr:
+				i, ok := argType.X.(*ast.Ident)
+				if !ok {
+					return nil, ErrUnexpectedType{"*ast.Ident", f.Type}
+				}
+				typeName = "*" + i.Name
+			default:
+				return nil, ErrUnexpectedType{"*ast.Ident or *ast.StarExpr", f.Type}
+			}
+
+			args = append(args, arg{name.Name, typeName})
+		}
+	}
+	return args, nil
+}

+ 168 - 0
pkg/plugins/pluginrpc-gen/parser_test.go

@@ -0,0 +1,168 @@
+package main
+
+import (
+	"fmt"
+	"path/filepath"
+	"runtime"
+	"strings"
+	"testing"
+)
+
+const testFixture = "fixtures/foo.go"
+
+func TestParseEmptyInterface(t *testing.T) {
+	pkg, err := Parse(testFixture, "Fooer")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	assertName(t, "foo", pkg.Name)
+	assertNum(t, 0, len(pkg.Functions))
+}
+
+func TestParseNonInterfaceType(t *testing.T) {
+	_, err := Parse(testFixture, "wobble")
+	if _, ok := err.(ErrUnexpectedType); !ok {
+		t.Fatal("expected type error when parsing non-interface type")
+	}
+}
+
+func TestParseWithOneFunction(t *testing.T) {
+	pkg, err := Parse(testFixture, "Fooer2")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	assertName(t, "foo", pkg.Name)
+	assertNum(t, 1, len(pkg.Functions))
+	assertName(t, "Foo", pkg.Functions[0].Name)
+	assertNum(t, 0, len(pkg.Functions[0].Args))
+	assertNum(t, 0, len(pkg.Functions[0].Returns))
+}
+
+func TestParseWithMultipleFuncs(t *testing.T) {
+	pkg, err := Parse(testFixture, "Fooer3")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	assertName(t, "foo", pkg.Name)
+	assertNum(t, 6, len(pkg.Functions))
+
+	f := pkg.Functions[0]
+	assertName(t, "Foo", f.Name)
+	assertNum(t, 0, len(f.Args))
+	assertNum(t, 0, len(f.Returns))
+
+	f = pkg.Functions[1]
+	assertName(t, "Bar", f.Name)
+	assertNum(t, 1, len(f.Args))
+	assertNum(t, 0, len(f.Returns))
+	arg := f.Args[0]
+	assertName(t, "a", arg.Name)
+	assertName(t, "string", arg.ArgType)
+
+	f = pkg.Functions[2]
+	assertName(t, "Baz", f.Name)
+	assertNum(t, 1, len(f.Args))
+	assertNum(t, 1, len(f.Returns))
+	arg = f.Args[0]
+	assertName(t, "a", arg.Name)
+	assertName(t, "string", arg.ArgType)
+	arg = f.Returns[0]
+	assertName(t, "err", arg.Name)
+	assertName(t, "error", arg.ArgType)
+
+	f = pkg.Functions[3]
+	assertName(t, "Qux", f.Name)
+	assertNum(t, 2, len(f.Args))
+	assertNum(t, 2, len(f.Returns))
+	arg = f.Args[0]
+	assertName(t, "a", f.Args[0].Name)
+	assertName(t, "string", f.Args[0].ArgType)
+	arg = f.Args[1]
+	assertName(t, "b", arg.Name)
+	assertName(t, "string", arg.ArgType)
+	arg = f.Returns[0]
+	assertName(t, "val", arg.Name)
+	assertName(t, "string", arg.ArgType)
+	arg = f.Returns[1]
+	assertName(t, "err", arg.Name)
+	assertName(t, "error", arg.ArgType)
+
+	f = pkg.Functions[4]
+	assertName(t, "Wobble", f.Name)
+	assertNum(t, 0, len(f.Args))
+	assertNum(t, 1, len(f.Returns))
+	arg = f.Returns[0]
+	assertName(t, "w", arg.Name)
+	assertName(t, "*wobble", arg.ArgType)
+
+	f = pkg.Functions[5]
+	assertName(t, "Wiggle", f.Name)
+	assertNum(t, 0, len(f.Args))
+	assertNum(t, 1, len(f.Returns))
+	arg = f.Returns[0]
+	assertName(t, "w", arg.Name)
+	assertName(t, "wobble", arg.ArgType)
+}
+
+func TestParseWithUnamedReturn(t *testing.T) {
+	_, err := Parse(testFixture, "Fooer4")
+	if !strings.HasSuffix(err.Error(), ErrBadReturn.Error()) {
+		t.Fatalf("expected ErrBadReturn, got %v", err)
+	}
+}
+
+func TestEmbeddedInterface(t *testing.T) {
+	pkg, err := Parse(testFixture, "Fooer5")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	assertName(t, "foo", pkg.Name)
+	assertNum(t, 2, len(pkg.Functions))
+
+	f := pkg.Functions[0]
+	assertName(t, "Foo", f.Name)
+	assertNum(t, 0, len(f.Args))
+	assertNum(t, 0, len(f.Returns))
+
+	f = pkg.Functions[1]
+	assertName(t, "Boo", f.Name)
+	assertNum(t, 2, len(f.Args))
+	assertNum(t, 2, len(f.Returns))
+
+	arg := f.Args[0]
+	assertName(t, "a", arg.Name)
+	assertName(t, "string", arg.ArgType)
+
+	arg = f.Args[1]
+	assertName(t, "b", arg.Name)
+	assertName(t, "string", arg.ArgType)
+
+	arg = f.Returns[0]
+	assertName(t, "s", arg.Name)
+	assertName(t, "string", arg.ArgType)
+
+	arg = f.Returns[1]
+	assertName(t, "err", arg.Name)
+	assertName(t, "error", arg.ArgType)
+}
+
+func assertName(t *testing.T, expected, actual string) {
+	if expected != actual {
+		fatalOut(t, fmt.Sprintf("expected name to be `%s`, got: %s", expected, actual))
+	}
+}
+
+func assertNum(t *testing.T, expected, actual int) {
+	if expected != actual {
+		fatalOut(t, fmt.Sprintf("expected number to be %d, got: %d", expected, actual))
+	}
+}
+
+func fatalOut(t *testing.T, msg string) {
+	_, file, ln, _ := runtime.Caller(2)
+	t.Fatalf("%s:%d: %s", filepath.Base(file), ln, msg)
+}

+ 97 - 0
pkg/plugins/pluginrpc-gen/template.go

@@ -0,0 +1,97 @@
+package main
+
+import (
+	"strings"
+	"text/template"
+)
+
+func printArgs(args []arg) string {
+	var argStr []string
+	for _, arg := range args {
+		argStr = append(argStr, arg.String())
+	}
+	return strings.Join(argStr, ", ")
+}
+
+func marshalType(t string) string {
+	switch t {
+	case "error":
+		// convert error types to plain strings to ensure the values are encoded/decoded properly
+		return "string"
+	default:
+		return t
+	}
+}
+
+func isErr(t string) bool {
+	switch t {
+	case "error":
+		return true
+	default:
+		return false
+	}
+}
+
+// Need to use this helper due to issues with go-vet
+func buildTag(s string) string {
+	return "+build " + s
+}
+
+var templFuncs = template.FuncMap{
+	"printArgs":   printArgs,
+	"marshalType": marshalType,
+	"isErr":       isErr,
+	"lower":       strings.ToLower,
+	"title":       strings.Title,
+	"tag":         buildTag,
+}
+
+var generatedTempl = template.Must(template.New("rpc_cient").Funcs(templFuncs).Parse(`
+// generated code - DO NOT EDIT
+{{ range $k, $v := .BuildTags }}
+	// {{ tag $k }} {{ end }}
+
+package {{ .Name }}
+
+import "errors"
+
+type client interface{
+	Call(string, interface{}, interface{}) error
+}
+
+type {{ .InterfaceType }}Proxy struct {
+	client
+}
+
+{{ range .Functions }}
+	type {{ $.InterfaceType }}Proxy{{ .Name }}Request struct{
+		{{ range .Args }}
+			{{ title .Name }} {{ .ArgType }} {{ end }}
+	}
+
+	type {{ $.InterfaceType }}Proxy{{ .Name }}Response struct{
+		{{ range .Returns }}
+			{{ title .Name }} {{ marshalType .ArgType }} {{ end }}
+	}
+
+	func (pp *{{ $.InterfaceType }}Proxy) {{ .Name }}({{ printArgs .Args }}) ({{ printArgs .Returns }}) {
+		var(
+			req {{ $.InterfaceType }}Proxy{{ .Name }}Request
+			ret {{ $.InterfaceType }}Proxy{{ .Name }}Response
+		)
+		{{ range .Args }}
+			req.{{ title .Name }} = {{ lower .Name }} {{ end }}
+		if err = pp.Call("{{ $.RPCName }}.{{ .Name }}", req, &ret); err != nil {
+			return
+		}
+		{{ range $r := .Returns }}
+			{{ if isErr .ArgType }}
+				if ret.{{ title .Name }} != "" {
+					{{ lower .Name }} = errors.New(ret.{{ title .Name }})
+				} {{ end }}
+			{{ if isErr .ArgType | not }} {{ lower .Name }} = ret.{{ title .Name }} {{ end }} {{ end }}
+
+		return
+	}
+{{ end }}
+`))