assert.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package assert
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/token"
  6. "reflect"
  7. "gotest.tools/v3/assert/cmp"
  8. "gotest.tools/v3/internal/format"
  9. "gotest.tools/v3/internal/source"
  10. )
  type (
	// LogT is the subset of testing.T used by the assert package.
	LogT interface {
		Log(args ...interface{})
	}

	// helperT is the optional capability of a test handle to mark the
	// calling function as a test helper (as testing.T's Helper does).
	// Functions in this file check for it with a type assertion on LogT.
	helperT interface {
		Helper()
	}
)

// failureMessage is the prefix prepended to assertion-failure messages
// before they are logged.
const failureMessage = "assertion failed: "
  19. // Eval the comparison and print a failure messages if the comparison has failed.
  20. // nolint: gocyclo
  21. func Eval(
  22. t LogT,
  23. argSelector argSelector,
  24. comparison interface{},
  25. msgAndArgs ...interface{},
  26. ) bool {
  27. if ht, ok := t.(helperT); ok {
  28. ht.Helper()
  29. }
  30. var success bool
  31. switch check := comparison.(type) {
  32. case bool:
  33. if check {
  34. return true
  35. }
  36. logFailureFromBool(t, msgAndArgs...)
  37. // Undocumented legacy comparison without Result type
  38. case func() (success bool, message string):
  39. success = runCompareFunc(t, check, msgAndArgs...)
  40. case nil:
  41. return true
  42. case error:
  43. msg := failureMsgFromError(check)
  44. t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
  45. case cmp.Comparison:
  46. success = RunComparison(t, argSelector, check, msgAndArgs...)
  47. case func() cmp.Result:
  48. success = RunComparison(t, argSelector, check, msgAndArgs...)
  49. default:
  50. t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
  51. }
  52. return success
  53. }
  54. func runCompareFunc(
  55. t LogT,
  56. f func() (success bool, message string),
  57. msgAndArgs ...interface{},
  58. ) bool {
  59. if ht, ok := t.(helperT); ok {
  60. ht.Helper()
  61. }
  62. if success, message := f(); !success {
  63. t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
  64. return false
  65. }
  66. return true
  67. }
  68. func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
  69. if ht, ok := t.(helperT); ok {
  70. ht.Helper()
  71. }
  72. const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
  73. args, err := source.CallExprArgs(stackIndex)
  74. if err != nil {
  75. t.Log(err.Error())
  76. return
  77. }
  78. const comparisonArgIndex = 1 // Assert(t, comparison)
  79. if len(args) <= comparisonArgIndex {
  80. t.Log(failureMessage + "but assert failed to find the expression to print")
  81. return
  82. }
  83. msg, err := boolFailureMessage(args[comparisonArgIndex])
  84. if err != nil {
  85. t.Log(err.Error())
  86. msg = "expression is false"
  87. }
  88. t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
  89. }
  // failureMsgFromError builds the failure text for a non-nil error passed as
// a comparison.
//
// An interface holding a typed nil (e.g. (*MyErr)(nil), or a nil map, slice,
// func or chan whose type implements error) is a non-nil error. For such
// values Error() is not called, because it may panic on a nil receiver or
// return nothing useful; the dynamic type is reported instead.
//
// err must not be a nil interface; Eval handles that case before calling
// this function.
func failureMsgFromError(err error) string {
	v := reflect.ValueOf(err)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan, reflect.UnsafePointer:
		// IsNil is only valid for these kinds; calling it on others panics.
		if v.IsNil() {
			return fmt.Sprintf("error is not nil: error has type %T", err)
		}
	}
	return "error is not nil: " + err.Error()
}
  98. func boolFailureMessage(expr ast.Expr) (string, error) {
  99. if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
  100. x, err := source.FormatNode(binaryExpr.X)
  101. if err != nil {
  102. return "", err
  103. }
  104. y, err := source.FormatNode(binaryExpr.Y)
  105. if err != nil {
  106. return "", err
  107. }
  108. return x + " is " + y, nil
  109. }
  110. if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
  111. x, err := source.FormatNode(unaryExpr.X)
  112. if err != nil {
  113. return "", err
  114. }
  115. return x + " is true", nil
  116. }
  117. formatted, err := source.FormatNode(expr)
  118. if err != nil {
  119. return "", err
  120. }
  121. return "expression is false: " + formatted, nil
  122. }