source.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package source // import "gotest.tools/internal/source"
  2. import (
  3. "bytes"
  4. "fmt"
  5. "go/ast"
  6. "go/format"
  7. "go/parser"
  8. "go/token"
  9. "os"
  10. "runtime"
  11. "strconv"
  12. "strings"
  13. "github.com/pkg/errors"
  14. )
  15. const baseStackIndex = 1
  16. // FormattedCallExprArg returns the argument from an ast.CallExpr at the
  17. // index in the call stack. The argument is formatted using FormatNode.
  18. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
  19. args, err := CallExprArgs(stackIndex + 1)
  20. if err != nil {
  21. return "", err
  22. }
  23. return FormatNode(args[argPos])
  24. }
  25. func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
  26. fileset := token.NewFileSet()
  27. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
  28. if err != nil {
  29. return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
  30. }
  31. node := scanToLine(fileset, astFile, lineNum)
  32. if node == nil {
  33. return nil, errors.Errorf(
  34. "failed to find an expression on line %d in %s", lineNum, filename)
  35. }
  36. return node, nil
  37. }
  38. func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
  39. v := &scanToLineVisitor{lineNum: lineNum, fileset: fileset}
  40. ast.Walk(v, node)
  41. return v.matchedNode
  42. }
  43. type scanToLineVisitor struct {
  44. lineNum int
  45. matchedNode ast.Node
  46. fileset *token.FileSet
  47. }
  48. func (v *scanToLineVisitor) Visit(node ast.Node) ast.Visitor {
  49. if node == nil || v.matchedNode != nil {
  50. return nil
  51. }
  52. if v.nodePosition(node).Line == v.lineNum {
  53. v.matchedNode = node
  54. return nil
  55. }
  56. return v
  57. }
  58. // In golang 1.9 the line number changed from being the line where the statement
  59. // ended to the line where the statement began.
  60. func (v *scanToLineVisitor) nodePosition(node ast.Node) token.Position {
  61. if goVersionBefore19 {
  62. return v.fileset.Position(node.End())
  63. }
  64. return v.fileset.Position(node.Pos())
  65. }
  66. var goVersionBefore19 = isGOVersionBefore19()
  67. func isGOVersionBefore19() bool {
  68. version := runtime.Version()
  69. // not a release version
  70. if !strings.HasPrefix(version, "go") {
  71. return false
  72. }
  73. version = strings.TrimPrefix(version, "go")
  74. parts := strings.Split(version, ".")
  75. if len(parts) < 2 {
  76. return false
  77. }
  78. minor, err := strconv.ParseInt(parts[1], 10, 32)
  79. return err == nil && parts[0] == "1" && minor < 9
  80. }
  81. func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
  82. visitor := &callExprVisitor{}
  83. ast.Walk(visitor, node)
  84. if visitor.expr == nil {
  85. return nil, errors.New("failed to find call expression")
  86. }
  87. return visitor.expr.Args, nil
  88. }
  89. type callExprVisitor struct {
  90. expr *ast.CallExpr
  91. }
  92. func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
  93. if v.expr != nil || node == nil {
  94. return nil
  95. }
  96. debug("visit (%T): %s", node, debugFormatNode{node})
  97. if callExpr, ok := node.(*ast.CallExpr); ok {
  98. v.expr = callExpr
  99. return nil
  100. }
  101. return v
  102. }
  103. // FormatNode using go/format.Node and return the result as a string
  104. func FormatNode(node ast.Node) (string, error) {
  105. buf := new(bytes.Buffer)
  106. err := format.Node(buf, token.NewFileSet(), node)
  107. return buf.String(), err
  108. }
  109. // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
  110. // the index in the call stack.
  111. func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
  112. _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
  113. if !ok {
  114. return nil, errors.New("failed to get call stack")
  115. }
  116. debug("call stack position: %s:%d", filename, lineNum)
  117. node, err := getNodeAtLine(filename, lineNum)
  118. if err != nil {
  119. return nil, err
  120. }
  121. debug("found node (%T): %s", node, debugFormatNode{node})
  122. return getCallExprArgs(node)
  123. }
  124. var debugEnabled = os.Getenv("GOTESTYOURSELF_DEBUG") != ""
  125. func debug(format string, args ...interface{}) {
  126. if debugEnabled {
  127. fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
  128. }
  129. }
  130. type debugFormatNode struct {
  131. ast.Node
  132. }
  133. func (n debugFormatNode) String() string {
  134. out, err := FormatNode(n.Node)
  135. if err != nil {
  136. return fmt.Sprintf("failed to format %s: %s", n.Node, err)
  137. }
  138. return out
  139. }