123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- package main
import (
	"errors"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"path"
	"reflect"
	"sort"
	"strconv"
	"strings"
)
// errBadReturn is returned by parseArgs when a parameter or result field has
// no name; every argument must be named.
var errBadReturn = errors.New("found return arg with no name: all args must be named")

// errUnexpectedType describes an AST node whose dynamic type differs from the
// one the parser needed at that point.
type errUnexpectedType struct {
	expected string      // description of the type that was wanted
	actual   interface{} // the value that was found instead
}

// Error reports the expected type together with the dynamic type of actual
// ("<nil>" when actual is nil).
func (u errUnexpectedType) Error() string {
	gotType := reflect.TypeOf(u.actual)
	return fmt.Sprintf("got wrong type expecting %s, got: %v", u.expected, gotType)
}
- // ParsedPkg holds information about a package that has been parsed,
- // its name and the list of functions.
- type ParsedPkg struct {
- Name string
- Functions []function
- Imports []importSpec
- }
- type function struct {
- Name string
- Args []arg
- Returns []arg
- Doc string
- }
// arg is a single named parameter or result of an interface method.
type arg struct {
	// Name is the identifier used for the argument in the method signature.
	Name string
	// ArgType is the argument's type rendered as Go source, e.g. "*foo.Bar".
	ArgType string
	// PackageSelector is the package qualifier used by the type (e.g. "foo"
	// for "*foo.Bar"); empty when the type is not package-qualified.
	PackageSelector string
}

// String renders the argument as it appears in a signature: its name, a
// space, then its type.
//
// It uses a value receiver so that formatting a plain arg value (for example
// with fmt.Sprint) uses this method too, not only formatting an *arg.
func (a arg) String() string {
	return a.Name + " " + a.ArgType
}
// importSpec is one import declaration referenced by the parsed interface.
type importSpec struct {
	// Name is the explicit import name (alias); empty for an unaliased import.
	Name string
	// Path is the import path literal exactly as written in source, including
	// its surrounding quotes, e.g. `"net/http"`.
	Path string
}

// String renders the spec as it appears inside an import declaration:
// `alias "path"` when aliased, otherwise just `"path"`.
//
// It uses a value receiver so that formatting a plain importSpec value (for
// example with fmt.Sprint) uses this method too, not only an *importSpec.
func (s importSpec) String() string {
	if s.Name == "" {
		return s.Path
	}
	// Separate the alias from the path literal, as gofmt prints it.
	return s.Name + " " + s.Path
}
- // Parse parses the given file for an interface definition with the given name.
- func Parse(filePath string, objName string) (*ParsedPkg, error) {
- fs := token.NewFileSet()
- pkg, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors)
- if err != nil {
- return nil, err
- }
- p := &ParsedPkg{}
- p.Name = pkg.Name.Name
- obj, exists := pkg.Scope.Objects[objName]
- if !exists {
- return nil, fmt.Errorf("could not find object %s in %s", objName, filePath)
- }
- if obj.Kind != ast.Typ {
- return nil, fmt.Errorf("exected type, got %s", obj.Kind)
- }
- spec, ok := obj.Decl.(*ast.TypeSpec)
- if !ok {
- return nil, errUnexpectedType{"*ast.TypeSpec", obj.Decl}
- }
- iface, ok := spec.Type.(*ast.InterfaceType)
- if !ok {
- return nil, errUnexpectedType{"*ast.InterfaceType", spec.Type}
- }
- p.Functions, err = parseInterface(iface)
- if err != nil {
- return nil, err
- }
- // figure out what imports will be needed
- imports := make(map[string]importSpec)
- for _, f := range p.Functions {
- args := append(f.Args, f.Returns...)
- for _, arg := range args {
- if len(arg.PackageSelector) == 0 {
- continue
- }
- for _, i := range pkg.Imports {
- if i.Name != nil {
- if i.Name.Name != arg.PackageSelector {
- continue
- }
- imports[i.Path.Value] = importSpec{Name: arg.PackageSelector, Path: i.Path.Value}
- break
- }
- _, name := path.Split(i.Path.Value)
- splitName := strings.Split(name, "-")
- if len(splitName) > 1 {
- name = splitName[len(splitName)-1]
- }
- // import paths have quotes already added in, so need to remove them for name comparison
- name = strings.TrimPrefix(name, `"`)
- name = strings.TrimSuffix(name, `"`)
- if name == arg.PackageSelector {
- imports[i.Path.Value] = importSpec{Path: i.Path.Value}
- break
- }
- }
- }
- }
- for _, spec := range imports {
- p.Imports = append(p.Imports, spec)
- }
- return p, nil
- }
- func parseInterface(iface *ast.InterfaceType) ([]function, error) {
- var functions []function
- for _, field := range iface.Methods.List {
- switch f := field.Type.(type) {
- case *ast.FuncType:
- method, err := parseFunc(field)
- if err != nil {
- return nil, err
- }
- if method == nil {
- continue
- }
- functions = append(functions, *method)
- case *ast.Ident:
- spec, ok := f.Obj.Decl.(*ast.TypeSpec)
- if !ok {
- return nil, errUnexpectedType{"*ast.TypeSpec", f.Obj.Decl}
- }
- iface, ok := spec.Type.(*ast.InterfaceType)
- if !ok {
- return nil, errUnexpectedType{"*ast.TypeSpec", spec.Type}
- }
- funcs, err := parseInterface(iface)
- if err != nil {
- fmt.Println(err)
- continue
- }
- functions = append(functions, funcs...)
- default:
- return nil, errUnexpectedType{"*astFuncType or *ast.Ident", f}
- }
- }
- return functions, nil
- }
- func parseFunc(field *ast.Field) (*function, error) {
- f := field.Type.(*ast.FuncType)
- method := &function{Name: field.Names[0].Name}
- if _, exists := skipFuncs[method.Name]; exists {
- fmt.Println("skipping:", method.Name)
- return nil, nil
- }
- if f.Params != nil {
- args, err := parseArgs(f.Params.List)
- if err != nil {
- return nil, err
- }
- method.Args = args
- }
- if f.Results != nil {
- returns, err := parseArgs(f.Results.List)
- if err != nil {
- return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err)
- }
- method.Returns = returns
- }
- return method, nil
- }
- func parseArgs(fields []*ast.Field) ([]arg, error) {
- var args []arg
- for _, f := range fields {
- if len(f.Names) == 0 {
- return nil, errBadReturn
- }
- for _, name := range f.Names {
- p, err := parseExpr(f.Type)
- if err != nil {
- return nil, err
- }
- args = append(args, arg{name.Name, p.value, p.pkg})
- }
- }
- return args, nil
- }
// parsedExpr is the result of rendering a type expression with parseExpr.
type parsedExpr struct {
	// value is the type written back out as Go source text.
	value string

	// pkg is the package qualifier the type refers to, or "" if none.
	pkg string
}
- func parseExpr(e ast.Expr) (parsedExpr, error) {
- var parsed parsedExpr
- switch i := e.(type) {
- case *ast.Ident:
- parsed.value += i.Name
- case *ast.StarExpr:
- p, err := parseExpr(i.X)
- if err != nil {
- return parsed, err
- }
- parsed.value += "*"
- parsed.value += p.value
- parsed.pkg = p.pkg
- case *ast.SelectorExpr:
- p, err := parseExpr(i.X)
- if err != nil {
- return parsed, err
- }
- parsed.pkg = p.value
- parsed.value += p.value + "."
- parsed.value += i.Sel.Name
- case *ast.MapType:
- parsed.value += "map["
- p, err := parseExpr(i.Key)
- if err != nil {
- return parsed, err
- }
- parsed.value += p.value
- parsed.value += "]"
- p, err = parseExpr(i.Value)
- if err != nil {
- return parsed, err
- }
- parsed.value += p.value
- parsed.pkg = p.pkg
- case *ast.ArrayType:
- parsed.value += "[]"
- p, err := parseExpr(i.Elt)
- if err != nil {
- return parsed, err
- }
- parsed.value += p.value
- parsed.pkg = p.pkg
- default:
- return parsed, errUnexpectedType{"*ast.Ident or *ast.StarExpr", i}
- }
- return parsed, nil
- }
|