source.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Package source provides utilities for handling source-code.
  2. package source // import "gotest.tools/v3/internal/source"
  3. import (
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "go/ast"
  8. "go/format"
  9. "go/parser"
  10. "go/token"
  11. "os"
  12. "runtime"
  13. )
  14. // FormattedCallExprArg returns the argument from an ast.CallExpr at the
  15. // index in the call stack. The argument is formatted using FormatNode.
  16. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
  17. args, err := CallExprArgs(stackIndex + 1)
  18. if err != nil {
  19. return "", err
  20. }
  21. if argPos >= len(args) {
  22. return "", errors.New("failed to find expression")
  23. }
  24. return FormatNode(args[argPos])
  25. }
  26. // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
  27. // the index in the call stack.
  28. func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
  29. _, filename, line, ok := runtime.Caller(stackIndex + 1)
  30. if !ok {
  31. return nil, errors.New("failed to get call stack")
  32. }
  33. debug("call stack position: %s:%d", filename, line)
  34. fileset := token.NewFileSet()
  35. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
  36. if err != nil {
  37. return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
  38. }
  39. expr, err := getCallExprArgs(fileset, astFile, line)
  40. if err != nil {
  41. return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
  42. }
  43. return expr, nil
  44. }
  45. func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
  46. if node := scanToLine(fileset, astFile, lineNum); node != nil {
  47. return node, nil
  48. }
  49. if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
  50. node, err := guessDefer(node)
  51. if err != nil || node != nil {
  52. return node, err
  53. }
  54. }
  55. return nil, nil
  56. }
  // scanToLine walks node depth-first with ast.Inspect and returns the first
  // node whose starting position (Pos) falls on line lineNum, or nil if no
  // such node exists. Once a match is recorded the walk stops descending.
  func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
  	var found ast.Node
  	inspect := func(candidate ast.Node) bool {
  		if candidate == nil || found != nil {
  			return false
  		}
  		if fileset.Position(candidate.Pos()).Line != lineNum {
  			// Not on the target line yet; keep descending.
  			return true
  		}
  		found = candidate
  		return false
  	}
  	ast.Inspect(node, inspect)
  	return found
  }
  71. func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
  72. node, err := getNodeAtLine(fileset, astFile, line)
  73. switch {
  74. case err != nil:
  75. return nil, err
  76. case node == nil:
  77. return nil, fmt.Errorf("failed to find an expression")
  78. }
  79. debug("found node: %s", debugFormatNode{node})
  80. visitor := &callExprVisitor{}
  81. ast.Walk(visitor, node)
  82. if visitor.expr == nil {
  83. return nil, errors.New("failed to find call expression")
  84. }
  85. debug("callExpr: %s", debugFormatNode{visitor.expr})
  86. return visitor.expr.Args, nil
  87. }
  88. type callExprVisitor struct {
  89. expr *ast.CallExpr
  90. }
  91. func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
  92. if v.expr != nil || node == nil {
  93. return nil
  94. }
  95. debug("visit: %s", debugFormatNode{node})
  96. switch typed := node.(type) {
  97. case *ast.CallExpr:
  98. v.expr = typed
  99. return nil
  100. case *ast.DeferStmt:
  101. ast.Walk(v, typed.Call.Fun)
  102. return nil
  103. }
  104. return v
  105. }
  // FormatNode renders node as Go source text using go/format.Node with a new,
  // empty token.FileSet. Any output written before an error occurred is still
  // returned together with that error.
  func FormatNode(node ast.Node) (string, error) {
  	var rendered bytes.Buffer
  	if err := format.Node(&rendered, token.NewFileSet(), node); err != nil {
  		return rendered.String(), err
  	}
  	return rendered.String(), nil
  }
  // debugEnabled reports whether the GOTESTTOOLS_DEBUG environment variable
  // was set to a non-empty value when the package was initialized.
  var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""

  // debug writes a "DEBUG: "-prefixed, newline-terminated message to stderr,
  // but only when debugEnabled is true. format and args follow fmt.Printf.
  func debug(format string, args ...interface{}) {
  	if !debugEnabled {
  		return
  	}
  	fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
  }
  118. type debugFormatNode struct {
  119. ast.Node
  120. }
  121. func (n debugFormatNode) String() string {
  122. if n.Node == nil {
  123. return "none"
  124. }
  125. out, err := FormatNode(n.Node)
  126. if err != nil {
  127. return fmt.Sprintf("failed to format %s: %s", n.Node, err)
  128. }
  129. return fmt.Sprintf("(%T) %s", n.Node, out)
  130. }