source.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package source // import "gotest.tools/v3/internal/source"
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "go/ast"
  7. "go/format"
  8. "go/parser"
  9. "go/token"
  10. "os"
  11. "runtime"
  12. )
  13. // FormattedCallExprArg returns the argument from an ast.CallExpr at the
  14. // index in the call stack. The argument is formatted using FormatNode.
  15. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
  16. args, err := CallExprArgs(stackIndex + 1)
  17. if err != nil {
  18. return "", err
  19. }
  20. if argPos >= len(args) {
  21. return "", errors.New("failed to find expression")
  22. }
  23. return FormatNode(args[argPos])
  24. }
  25. // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
  26. // the index in the call stack.
  27. func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
  28. _, filename, line, ok := runtime.Caller(stackIndex + 1)
  29. if !ok {
  30. return nil, errors.New("failed to get call stack")
  31. }
  32. debug("call stack position: %s:%d", filename, line)
  33. fileset := token.NewFileSet()
  34. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
  35. if err != nil {
  36. return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
  37. }
  38. expr, err := getCallExprArgs(fileset, astFile, line)
  39. if err != nil {
  40. return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
  41. }
  42. return expr, nil
  43. }
  44. func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
  45. if node := scanToLine(fileset, astFile, lineNum); node != nil {
  46. return node, nil
  47. }
  48. if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
  49. node, err := guessDefer(node)
  50. if err != nil || node != nil {
  51. return node, err
  52. }
  53. }
  54. return nil, nil
  55. }
  56. func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
  57. var matchedNode ast.Node
  58. ast.Inspect(node, func(node ast.Node) bool {
  59. switch {
  60. case node == nil || matchedNode != nil:
  61. return false
  62. case fileset.Position(node.Pos()).Line == lineNum:
  63. matchedNode = node
  64. return false
  65. }
  66. return true
  67. })
  68. return matchedNode
  69. }
  70. func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
  71. node, err := getNodeAtLine(fileset, astFile, line)
  72. switch {
  73. case err != nil:
  74. return nil, err
  75. case node == nil:
  76. return nil, fmt.Errorf("failed to find an expression")
  77. }
  78. debug("found node: %s", debugFormatNode{node})
  79. visitor := &callExprVisitor{}
  80. ast.Walk(visitor, node)
  81. if visitor.expr == nil {
  82. return nil, errors.New("failed to find call expression")
  83. }
  84. debug("callExpr: %s", debugFormatNode{visitor.expr})
  85. return visitor.expr.Args, nil
  86. }
  87. type callExprVisitor struct {
  88. expr *ast.CallExpr
  89. }
  90. func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
  91. if v.expr != nil || node == nil {
  92. return nil
  93. }
  94. debug("visit: %s", debugFormatNode{node})
  95. switch typed := node.(type) {
  96. case *ast.CallExpr:
  97. v.expr = typed
  98. return nil
  99. case *ast.DeferStmt:
  100. ast.Walk(v, typed.Call.Fun)
  101. return nil
  102. }
  103. return v
  104. }
  105. // FormatNode using go/format.Node and return the result as a string
  106. func FormatNode(node ast.Node) (string, error) {
  107. buf := new(bytes.Buffer)
  108. err := format.Node(buf, token.NewFileSet(), node)
  109. return buf.String(), err
  110. }
  111. var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""
  112. func debug(format string, args ...interface{}) {
  113. if debugEnabled {
  114. fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
  115. }
  116. }
  117. type debugFormatNode struct {
  118. ast.Node
  119. }
  120. func (n debugFormatNode) String() string {
  121. if n.Node == nil {
  122. return "none"
  123. }
  124. out, err := FormatNode(n.Node)
  125. if err != nil {
  126. return fmt.Sprintf("failed to format %s: %s", n.Node, err)
  127. }
  128. return fmt.Sprintf("(%T) %s", n.Node, out)
  129. }