123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- package assert
- import (
- "errors"
- "fmt"
- "go/ast"
- "gotest.tools/v3/assert/cmp"
- "gotest.tools/v3/internal/format"
- "gotest.tools/v3/internal/source"
- )
- // RunComparison and return Comparison.Success. If the comparison fails a messages
- // will be printed using t.Log.
- func RunComparison(
- t LogT,
- argSelector argSelector,
- f cmp.Comparison,
- msgAndArgs ...interface{},
- ) bool {
- if ht, ok := t.(helperT); ok {
- ht.Helper()
- }
- result := f()
- if result.Success() {
- return true
- }
- if source.IsUpdate() {
- if updater, ok := result.(updateExpected); ok {
- const stackIndex = 3 // Assert/Check, assert, RunComparison
- err := updater.UpdatedExpected(stackIndex)
- switch {
- case err == nil:
- return true
- case errors.Is(err, source.ErrNotFound):
- // do nothing, fallthrough to regular failure message
- default:
- t.Log("failed to update source", err)
- return false
- }
- }
- }
- var message string
- switch typed := result.(type) {
- case resultWithComparisonArgs:
- const stackIndex = 3 // Assert/Check, assert, RunComparison
- args, err := source.CallExprArgs(stackIndex)
- if err != nil {
- t.Log(err.Error())
- }
- message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
- case resultBasic:
- message = typed.FailureMessage()
- default:
- message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
- }
- t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
- return false
- }
// Optional behaviours of a failed comparison result. RunComparison
// type-asserts the result against these interfaces to decide how to build the
// failure message or whether to attempt a source update.
type (
	// resultWithComparisonArgs is implemented by results whose failure
	// message is built from the source expressions of the compared arguments.
	resultWithComparisonArgs interface {
		FailureMessage(args []ast.Expr) string
	}

	// resultBasic is implemented by results that build their failure message
	// without any source expressions.
	resultBasic interface {
		FailureMessage() string
	}

	// updateExpected is implemented by results that can update the expected
	// value in the source; it is only consulted when source.IsUpdate is true.
	updateExpected interface {
		UpdatedExpected(stackIndex int) error
	}
)
- // filterPrintableExpr filters the ast.Expr slice to only include Expr that are
- // easy to read when printed and contain relevant information to an assertion.
- //
- // Ident and SelectorExpr are included because they print nicely and the variable
- // names may provide additional context to their values.
- // BasicLit and CompositeLit are excluded because their source is equivalent to
- // their value, which is already available.
- // Other types are ignored for now, but could be added if they are relevant.
- func filterPrintableExpr(args []ast.Expr) []ast.Expr {
- result := make([]ast.Expr, len(args))
- for i, arg := range args {
- if isShortPrintableExpr(arg) {
- result[i] = arg
- continue
- }
- if starExpr, ok := arg.(*ast.StarExpr); ok {
- result[i] = starExpr.X
- continue
- }
- }
- return result
- }
// isShortPrintableExpr reports whether expr is an expression kind whose source
// text is compact enough to appear in a failure message. Accepted kinds are
// identifiers, selectors (x.y), index (x[i]) and slice (x[i:j]) expressions,
// and binary and unary expressions. Every other node is rejected, including
// nil, calls, parenthesized expressions, type assertions, key/value pairs,
// star expressions and literals.
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	}
	return false
}
// argSelector picks, from the argument expressions of the assertion call being
// reported, the expressions passed on to a result's FailureMessage.
// ArgsAfterT, ArgsFromComparisonCall and ArgsAtZeroIndex all have this shape.
type argSelector func(args []ast.Expr) []ast.Expr
// ArgsAfterT returns every argument except the first. It is meant for callers
// whose first argument is a testing.T, so the expressions of interest follow
// it. It returns nil when args is empty.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
	if n := len(args); n > 0 {
		return args[1:n]
	}
	return nil
}
// ArgsFromComparisonCall returns the arguments of the call expression at
// position 1 of args. It is meant for callers that take a testing.T first and
// pass the cmp.Comparison call second, so the expressions of interest are the
// arguments of that call. It returns nil when there is no second argument or
// when the second argument is not a call expression.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
	if len(args) < 2 {
		return nil
	}
	call, isCall := args[1].(*ast.CallExpr)
	if !isCall {
		return nil
	}
	return call.Args
}
// ArgsAtZeroIndex returns the arguments of the call expression at position 0
// of args. It is meant for callers that accept a single cmp.Comparison
// argument, so the comparison call is the first argument and the expressions
// of interest are that call's arguments. It returns nil when args is empty or
// when its first element is not a call expression.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
	if len(args) == 0 {
		return nil
	}
	if callExpr, ok := args[0].(*ast.CallExpr); ok {
		return callExpr.Args
	}
	return nil
}
|