source.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package source // import "gotest.tools/v3/internal/source"
  2. import (
  3. "bytes"
  4. "fmt"
  5. "go/ast"
  6. "go/format"
  7. "go/parser"
  8. "go/token"
  9. "os"
  10. "runtime"
  11. "strconv"
  12. "strings"
  13. "github.com/pkg/errors"
  14. )
  15. const baseStackIndex = 1
  16. // FormattedCallExprArg returns the argument from an ast.CallExpr at the
  17. // index in the call stack. The argument is formatted using FormatNode.
  18. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
  19. args, err := CallExprArgs(stackIndex + 1)
  20. if err != nil {
  21. return "", err
  22. }
  23. if argPos >= len(args) {
  24. return "", errors.New("failed to find expression")
  25. }
  26. return FormatNode(args[argPos])
  27. }
  28. // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
  29. // the index in the call stack.
  30. func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
  31. _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
  32. if !ok {
  33. return nil, errors.New("failed to get call stack")
  34. }
  35. debug("call stack position: %s:%d", filename, lineNum)
  36. node, err := getNodeAtLine(filename, lineNum)
  37. if err != nil {
  38. return nil, err
  39. }
  40. debug("found node: %s", debugFormatNode{node})
  41. return getCallExprArgs(node)
  42. }
  43. func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
  44. fileset := token.NewFileSet()
  45. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
  46. if err != nil {
  47. return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
  48. }
  49. if node := scanToLine(fileset, astFile, lineNum); node != nil {
  50. return node, nil
  51. }
  52. if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
  53. node, err := guessDefer(node)
  54. if err != nil || node != nil {
  55. return node, err
  56. }
  57. }
  58. return nil, errors.Errorf(
  59. "failed to find an expression on line %d in %s", lineNum, filename)
  60. }
  61. func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
  62. var matchedNode ast.Node
  63. ast.Inspect(node, func(node ast.Node) bool {
  64. switch {
  65. case node == nil || matchedNode != nil:
  66. return false
  67. case nodePosition(fileset, node).Line == lineNum:
  68. matchedNode = node
  69. return false
  70. }
  71. return true
  72. })
  73. return matchedNode
  74. }
  75. // In golang 1.9 the line number changed from being the line where the statement
  76. // ended to the line where the statement began.
  77. func nodePosition(fileset *token.FileSet, node ast.Node) token.Position {
  78. if goVersionBefore19 {
  79. return fileset.Position(node.End())
  80. }
  81. return fileset.Position(node.Pos())
  82. }
  83. // GoVersionLessThan returns true if runtime.Version() is semantically less than
  84. // version major.minor. Returns false if a release version can not be parsed from
  85. // runtime.Version().
  86. func GoVersionLessThan(major, minor int64) bool {
  87. version := runtime.Version()
  88. // not a release version
  89. if !strings.HasPrefix(version, "go") {
  90. return false
  91. }
  92. version = strings.TrimPrefix(version, "go")
  93. parts := strings.Split(version, ".")
  94. if len(parts) < 2 {
  95. return false
  96. }
  97. rMajor, err := strconv.ParseInt(parts[0], 10, 32)
  98. if err != nil {
  99. return false
  100. }
  101. if rMajor != major {
  102. return rMajor < major
  103. }
  104. rMinor, err := strconv.ParseInt(parts[1], 10, 32)
  105. if err != nil {
  106. return false
  107. }
  108. return rMinor < minor
  109. }
  110. var goVersionBefore19 = GoVersionLessThan(1, 9)
  111. func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
  112. visitor := &callExprVisitor{}
  113. ast.Walk(visitor, node)
  114. if visitor.expr == nil {
  115. return nil, errors.New("failed to find call expression")
  116. }
  117. debug("callExpr: %s", debugFormatNode{visitor.expr})
  118. return visitor.expr.Args, nil
  119. }
  120. type callExprVisitor struct {
  121. expr *ast.CallExpr
  122. }
  123. func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
  124. if v.expr != nil || node == nil {
  125. return nil
  126. }
  127. debug("visit: %s", debugFormatNode{node})
  128. switch typed := node.(type) {
  129. case *ast.CallExpr:
  130. v.expr = typed
  131. return nil
  132. case *ast.DeferStmt:
  133. ast.Walk(v, typed.Call.Fun)
  134. return nil
  135. }
  136. return v
  137. }
  138. // FormatNode using go/format.Node and return the result as a string
  139. func FormatNode(node ast.Node) (string, error) {
  140. buf := new(bytes.Buffer)
  141. err := format.Node(buf, token.NewFileSet(), node)
  142. return buf.String(), err
  143. }
  144. var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""
  145. func debug(format string, args ...interface{}) {
  146. if debugEnabled {
  147. fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
  148. }
  149. }
  150. type debugFormatNode struct {
  151. ast.Node
  152. }
  153. func (n debugFormatNode) String() string {
  154. out, err := FormatNode(n.Node)
  155. if err != nil {
  156. return fmt.Sprintf("failed to format %s: %s", n.Node, err)
  157. }
  158. return fmt.Sprintf("(%T) %s", n.Node, out)
  159. }