result.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package assert
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/ast"
  6. "gotest.tools/v3/assert/cmp"
  7. "gotest.tools/v3/internal/format"
  8. "gotest.tools/v3/internal/source"
  9. )
  10. // RunComparison and return Comparison.Success. If the comparison fails a messages
  11. // will be printed using t.Log.
  12. func RunComparison(
  13. t LogT,
  14. argSelector argSelector,
  15. f cmp.Comparison,
  16. msgAndArgs ...interface{},
  17. ) bool {
  18. if ht, ok := t.(helperT); ok {
  19. ht.Helper()
  20. }
  21. result := f()
  22. if result.Success() {
  23. return true
  24. }
  25. if source.IsUpdate() {
  26. if updater, ok := result.(updateExpected); ok {
  27. const stackIndex = 3 // Assert/Check, assert, RunComparison
  28. err := updater.UpdatedExpected(stackIndex)
  29. switch {
  30. case err == nil:
  31. return true
  32. case errors.Is(err, source.ErrNotFound):
  33. // do nothing, fallthrough to regular failure message
  34. default:
  35. t.Log("failed to update source", err)
  36. return false
  37. }
  38. }
  39. }
  40. var message string
  41. switch typed := result.(type) {
  42. case resultWithComparisonArgs:
  43. const stackIndex = 3 // Assert/Check, assert, RunComparison
  44. args, err := source.CallExprArgs(stackIndex)
  45. if err != nil {
  46. t.Log(err.Error())
  47. }
  48. message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
  49. case resultBasic:
  50. message = typed.FailureMessage()
  51. default:
  52. message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
  53. }
  54. t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
  55. return false
  56. }
  // Optional interfaces a comparison result may implement. RunComparison
  // detects them with type assertions to choose how a failure is described and
  // whether the expected value can be updated.
  type (
  	// resultWithComparisonArgs is a result whose failure message is built
  	// from the source expressions of the compared arguments.
  	resultWithComparisonArgs interface {
  		FailureMessage(args []ast.Expr) string
  	}

  	// resultBasic is a result that can describe its failure without any
  	// source expressions.
  	resultBasic interface {
  		FailureMessage() string
  	}

  	// updateExpected is a result that can attempt to update its expected
  	// value; stackIndex is the caller depth to use for that update.
  	updateExpected interface {
  		UpdatedExpected(stackIndex int) error
  	}
  )
  66. // filterPrintableExpr filters the ast.Expr slice to only include Expr that are
  67. // easy to read when printed and contain relevant information to an assertion.
  68. //
  69. // Ident and SelectorExpr are included because they print nicely and the variable
  70. // names may provide additional context to their values.
  71. // BasicLit and CompositeLit are excluded because their source is equivalent to
  72. // their value, which is already available.
  73. // Other types are ignored for now, but could be added if they are relevant.
  74. func filterPrintableExpr(args []ast.Expr) []ast.Expr {
  75. result := make([]ast.Expr, len(args))
  76. for i, arg := range args {
  77. if isShortPrintableExpr(arg) {
  78. result[i] = arg
  79. continue
  80. }
  81. if starExpr, ok := arg.(*ast.StarExpr); ok {
  82. result[i] = starExpr.X
  83. continue
  84. }
  85. }
  86. return result
  87. }
  // isShortPrintableExpr reports whether expr is a kind of expression whose
  // source text is compact enough to show in a failure message.
  //
  // Accepted: identifiers, selectors (a.b), index (a[i]), slice (a[i:j]),
  // binary (a+b) and unary (-a, &a) expressions. Everything else is rejected,
  // including calls, parenthesized expressions, type assertions, key/value
  // pairs, pointer dereferences (*a), literals and nil.
  func isShortPrintableExpr(expr ast.Expr) bool {
  	switch expr.(type) {
  	case *ast.Ident,
  		*ast.SelectorExpr,
  		*ast.IndexExpr,
  		*ast.SliceExpr,
  		*ast.BinaryExpr,
  		*ast.UnaryExpr:
  		return true
  	}
  	return false
  }
  // argSelector maps the argument expressions of an assertion call to the
  // expressions that are passed on to a result's FailureMessage.
  type argSelector func(args []ast.Expr) []ast.Expr
  // ArgsAfterT returns every argument after the first one. It is meant for
  // callers whose first argument is a testing.T, so the expressions to select
  // follow it. It returns nil when args is nil or empty.
  func ArgsAfterT(args []ast.Expr) []ast.Expr {
  	if len(args) > 0 {
  		return args[1:]
  	}
  	return nil
  }
  // ArgsFromComparisonCall returns the arguments of the call expression at
  // position 1 of args. It is meant for callers that take a testing.T first and
  // the cmp.Comparison call second, so the interesting operands are that call's
  // arguments. It returns nil when there are fewer than two args or when
  // args[1] is not a call expression.
  func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
  	if len(args) < 2 {
  		return nil
  	}
  	call, isCall := args[1].(*ast.CallExpr)
  	if !isCall {
  		return nil
  	}
  	return call.Args
  }
  // ArgsAtZeroIndex selects the arguments of the call expression at position 0
  // of args. It is used when the caller accepts a single cmp.Comparison
  // argument, so the operands to describe are the arguments of that call.
  //
  // It returns nil when args is empty or when args[0] is not a call expression.
  func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
  	if len(args) == 0 {
  		return nil
  	}
  	if callExpr, ok := args[0].(*ast.CallExpr); ok {
  		return callExpr.Args
  	}
  	return nil
  }