assert.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
// Package assert provides internal utilities for assertions.
  2. package assert
  3. import (
  4. "fmt"
  5. "go/ast"
  6. "go/token"
  7. "reflect"
  8. "gotest.tools/v3/assert/cmp"
  9. "gotest.tools/v3/internal/format"
  10. "gotest.tools/v3/internal/source"
  11. )
  12. // LogT is the subset of testing.T used by the assert package.
  13. type LogT interface {
  14. Log(args ...interface{})
  15. }
  16. type helperT interface {
  17. Helper()
  18. }
  19. const failureMessage = "assertion failed: "
  20. // Eval the comparison and print a failure messages if the comparison has failed.
  21. func Eval(
  22. t LogT,
  23. argSelector argSelector,
  24. comparison interface{},
  25. msgAndArgs ...interface{},
  26. ) bool {
  27. if ht, ok := t.(helperT); ok {
  28. ht.Helper()
  29. }
  30. var success bool
  31. switch check := comparison.(type) {
  32. case bool:
  33. if check {
  34. return true
  35. }
  36. logFailureFromBool(t, msgAndArgs...)
  37. // Undocumented legacy comparison without Result type
  38. case func() (success bool, message string):
  39. success = runCompareFunc(t, check, msgAndArgs...)
  40. case nil:
  41. return true
  42. case error:
  43. msg := failureMsgFromError(check)
  44. t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
  45. case cmp.Comparison:
  46. success = RunComparison(t, argSelector, check, msgAndArgs...)
  47. case func() cmp.Result:
  48. success = RunComparison(t, argSelector, check, msgAndArgs...)
  49. default:
  50. t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
  51. }
  52. return success
  53. }
  54. func runCompareFunc(
  55. t LogT,
  56. f func() (success bool, message string),
  57. msgAndArgs ...interface{},
  58. ) bool {
  59. if ht, ok := t.(helperT); ok {
  60. ht.Helper()
  61. }
  62. if success, message := f(); !success {
  63. t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
  64. return false
  65. }
  66. return true
  67. }
  68. func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
  69. if ht, ok := t.(helperT); ok {
  70. ht.Helper()
  71. }
  72. const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
  73. args, err := source.CallExprArgs(stackIndex)
  74. if err != nil {
  75. t.Log(err.Error())
  76. return
  77. }
  78. const comparisonArgIndex = 1 // Assert(t, comparison)
  79. if len(args) <= comparisonArgIndex {
  80. t.Log(failureMessage + "but assert failed to find the expression to print")
  81. return
  82. }
  83. msg, err := boolFailureMessage(args[comparisonArgIndex])
  84. if err != nil {
  85. t.Log(err.Error())
  86. msg = "expression is false"
  87. }
  88. t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
  89. }
  90. func failureMsgFromError(err error) string {
  91. // Handle errors with non-nil types
  92. v := reflect.ValueOf(err)
  93. if v.Kind() == reflect.Ptr && v.IsNil() {
  94. return fmt.Sprintf("error is not nil: error has type %T", err)
  95. }
  96. return "error is not nil: " + err.Error()
  97. }
  98. func boolFailureMessage(expr ast.Expr) (string, error) {
  99. if binaryExpr, ok := expr.(*ast.BinaryExpr); ok {
  100. x, err := source.FormatNode(binaryExpr.X)
  101. if err != nil {
  102. return "", err
  103. }
  104. y, err := source.FormatNode(binaryExpr.Y)
  105. if err != nil {
  106. return "", err
  107. }
  108. switch binaryExpr.Op {
  109. case token.NEQ:
  110. return x + " is " + y, nil
  111. case token.EQL:
  112. return x + " is not " + y, nil
  113. case token.GTR:
  114. return x + " is <= " + y, nil
  115. case token.LSS:
  116. return x + " is >= " + y, nil
  117. case token.GEQ:
  118. return x + " is less than " + y, nil
  119. case token.LEQ:
  120. return x + " is greater than " + y, nil
  121. }
  122. }
  123. if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
  124. x, err := source.FormatNode(unaryExpr.X)
  125. if err != nil {
  126. return "", err
  127. }
  128. return x + " is true", nil
  129. }
  130. if ident, ok := expr.(*ast.Ident); ok {
  131. return ident.Name + " is false", nil
  132. }
  133. formatted, err := source.FormatNode(expr)
  134. if err != nil {
  135. return "", err
  136. }
  137. return "expression is false: " + formatted, nil
  138. }