compare.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. /*Package cmp provides Comparisons for Assert and Check*/
  2. package cmp // import "gotest.tools/v3/assert/cmp"
  3. import (
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "strings"
  9. "github.com/google/go-cmp/cmp"
  10. "gotest.tools/v3/internal/format"
  11. )
  12. // Comparison is a function which compares values and returns [ResultSuccess] if
  13. // the actual value matches the expected value. If the values do not match the
  14. // [Result] will contain a message about why it failed.
  15. type Comparison func() Result
  16. // DeepEqual compares two values using [github.com/google/go-cmp/cmp]
  17. // and succeeds if the values are equal.
  18. //
  19. // The comparison can be customized using comparison Options.
  20. // Package [gotest.tools/v3/assert/opt] provides some additional
  21. // commonly used Options.
  22. func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
  23. return func() (result Result) {
  24. defer func() {
  25. if panicmsg, handled := handleCmpPanic(recover()); handled {
  26. result = ResultFailure(panicmsg)
  27. }
  28. }()
  29. diff := cmp.Diff(x, y, opts...)
  30. if diff == "" {
  31. return ResultSuccess
  32. }
  33. return multiLineDiffResult(diff, x, y)
  34. }
  35. }
  // handleCmpPanic inspects a value obtained from recover() during a go-cmp
// comparison.
//
// A nil value means nothing panicked and yields ("", false). A string value
// starting with "cannot handle unexported field" is returned together with
// true so the caller can report it as a failure. Every other value, including
// any non-string panic value, is re-panicked unchanged.
func handleCmpPanic(r interface{}) (string, bool) {
	if r == nil {
		return "", false
	}
	if msg, isString := r.(string); isString && strings.HasPrefix(msg, "cannot handle unexported field") {
		return msg, true
	}
	panic(r)
}
  50. func toResult(success bool, msg string) Result {
  51. if success {
  52. return ResultSuccess
  53. }
  54. return ResultFailure(msg)
  55. }
  56. // RegexOrPattern may be either a [*regexp.Regexp] or a string that is a valid
  57. // regexp pattern.
  58. type RegexOrPattern interface{}
  59. // Regexp succeeds if value v matches regular expression re.
  60. //
  61. // Example:
  62. //
  63. // assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
  64. // r := regexp.MustCompile("^[0-9a-f]{32}$")
  65. // assert.Assert(t, cmp.Regexp(r, str))
  66. func Regexp(re RegexOrPattern, v string) Comparison {
  67. match := func(re *regexp.Regexp) Result {
  68. return toResult(
  69. re.MatchString(v),
  70. fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
  71. }
  72. return func() Result {
  73. switch regex := re.(type) {
  74. case *regexp.Regexp:
  75. return match(regex)
  76. case string:
  77. re, err := regexp.Compile(regex)
  78. if err != nil {
  79. return ResultFailure(err.Error())
  80. }
  81. return match(re)
  82. default:
  83. return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex))
  84. }
  85. }
  86. }
  87. // Equal succeeds if x == y. See [gotest.tools/v3/assert.Equal] for full documentation.
  88. func Equal(x, y interface{}) Comparison {
  89. return func() Result {
  90. switch {
  91. case x == y:
  92. return ResultSuccess
  93. case isMultiLineStringCompare(x, y):
  94. diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
  95. return multiLineDiffResult(diff, x, y)
  96. }
  97. return ResultFailureTemplate(`
  98. {{- printf "%v" .Data.x}} (
  99. {{- with callArg 0 }}{{ formatNode . }} {{end -}}
  100. {{- printf "%T" .Data.x -}}
  101. ) != {{ printf "%v" .Data.y}} (
  102. {{- with callArg 1 }}{{ formatNode . }} {{end -}}
  103. {{- printf "%T" .Data.y -}}
  104. )`,
  105. map[string]interface{}{"x": x, "y": y})
  106. }
  107. }
  // isMultiLineStringCompare reports whether x and y are both strings and at
// least one of them contains a newline character.
func isMultiLineStringCompare(x, y interface{}) bool {
	strX, okX := x.(string)
	strY, okY := y.(string)
	if !okX || !okY {
		return false
	}
	return strings.ContainsRune(strX, '\n') || strings.ContainsRune(strY, '\n')
}
  119. func multiLineDiffResult(diff string, x, y interface{}) Result {
  120. return ResultFailureTemplate(`
  121. --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
  122. +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
  123. {{ .Data.diff }}`,
  124. map[string]interface{}{"diff": diff, "x": x, "y": y})
  125. }
  126. // Len succeeds if the sequence has the expected length.
  127. func Len(seq interface{}, expected int) Comparison {
  128. return func() (result Result) {
  129. defer func() {
  130. if e := recover(); e != nil {
  131. result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
  132. }
  133. }()
  134. value := reflect.ValueOf(seq)
  135. length := value.Len()
  136. if length == expected {
  137. return ResultSuccess
  138. }
  139. msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
  140. return ResultFailure(msg)
  141. }
  142. }
  143. // Contains succeeds if item is in collection. Collection may be a string, map,
  144. // slice, or array.
  145. //
  146. // If collection is a string, item must also be a string, and is compared using
  147. // [strings.Contains].
  148. // If collection is a Map, contains will succeed if item is a key in the map.
  149. // If collection is a slice or array, item is compared to each item in the
  150. // sequence using [reflect.DeepEqual].
  151. func Contains(collection interface{}, item interface{}) Comparison {
  152. return func() Result {
  153. colValue := reflect.ValueOf(collection)
  154. if !colValue.IsValid() {
  155. return ResultFailure("nil does not contain items")
  156. }
  157. msg := fmt.Sprintf("%v does not contain %v", collection, item)
  158. itemValue := reflect.ValueOf(item)
  159. switch colValue.Type().Kind() {
  160. case reflect.String:
  161. if itemValue.Type().Kind() != reflect.String {
  162. return ResultFailure("string may only contain strings")
  163. }
  164. return toResult(
  165. strings.Contains(colValue.String(), itemValue.String()),
  166. fmt.Sprintf("string %q does not contain %q", collection, item))
  167. case reflect.Map:
  168. if itemValue.Type() != colValue.Type().Key() {
  169. return ResultFailure(fmt.Sprintf(
  170. "%v can not contain a %v key", colValue.Type(), itemValue.Type()))
  171. }
  172. return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
  173. case reflect.Slice, reflect.Array:
  174. for i := 0; i < colValue.Len(); i++ {
  175. if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
  176. return ResultSuccess
  177. }
  178. }
  179. return ResultFailure(msg)
  180. default:
  181. return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
  182. }
  183. }
  184. }
  185. // Panics succeeds if f() panics.
  186. func Panics(f func()) Comparison {
  187. return func() (result Result) {
  188. defer func() {
  189. if err := recover(); err != nil {
  190. result = ResultSuccess
  191. }
  192. }()
  193. f()
  194. return ResultFailure("did not panic")
  195. }
  196. }
  197. // Error succeeds if err is a non-nil error, and the error message equals the
  198. // expected message.
  199. func Error(err error, message string) Comparison {
  200. return func() Result {
  201. switch {
  202. case err == nil:
  203. return ResultFailure("expected an error, got nil")
  204. case err.Error() != message:
  205. return ResultFailure(fmt.Sprintf(
  206. "expected error %q, got %s", message, formatErrorMessage(err)))
  207. }
  208. return ResultSuccess
  209. }
  210. }
  211. // ErrorContains succeeds if err is a non-nil error, and the error message contains
  212. // the expected substring.
  213. func ErrorContains(err error, substring string) Comparison {
  214. return func() Result {
  215. switch {
  216. case err == nil:
  217. return ResultFailure("expected an error, got nil")
  218. case !strings.Contains(err.Error(), substring):
  219. return ResultFailure(fmt.Sprintf(
  220. "expected error to contain %q, got %s", substring, formatErrorMessage(err)))
  221. }
  222. return ResultSuccess
  223. }
  224. }
  // causer is implemented by errors that expose an underlying cause, such as
// errors wrapped with github.com/pkg/errors.
type causer interface {
	Cause() error
}

// formatErrorMessage renders err for use in a failure message. The result is
// always the quoted error text (%q). When err implements causer, the %+v
// rendering of err is appended on a new line.
func formatErrorMessage(err error) string {
	quoted := fmt.Sprintf("%q", err)
	//nolint:errorlint // unwrapping is not appropriate here
	if _, hasCause := err.(causer); hasCause {
		return quoted + "\n" + fmt.Sprintf("%+v", err)
	}
	// This error was not wrapped with github.com/pkg/errors.
	return quoted
}
  236. // Nil succeeds if obj is a nil interface, pointer, or function.
  237. //
  238. // Use [gotest.tools/v3/assert.NilError] for comparing errors. Use Len(obj, 0) for comparing slices,
  239. // maps, and channels.
  240. func Nil(obj interface{}) Comparison {
  241. msgFunc := func(value reflect.Value) string {
  242. return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
  243. }
  244. return isNil(obj, msgFunc)
  245. }
  246. func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
  247. return func() Result {
  248. if obj == nil {
  249. return ResultSuccess
  250. }
  251. value := reflect.ValueOf(obj)
  252. kind := value.Type().Kind()
  253. if kind >= reflect.Chan && kind <= reflect.Slice {
  254. if value.IsNil() {
  255. return ResultSuccess
  256. }
  257. return ResultFailure(msgFunc(value))
  258. }
  259. return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
  260. }
  261. }
  262. // ErrorType succeeds if err is not nil and is of the expected type.
  263. //
  264. // Expected can be one of:
  265. //
  266. // func(error) bool
  267. //
  268. // Function should return true if the error is the expected type.
  269. //
  270. // type struct{}, type &struct{}
  271. //
  272. // A struct or a pointer to a struct.
  273. // Fails if the error is not of the same type as expected.
  274. //
  275. // type &interface{}
  276. //
  277. // A pointer to an interface type.
  278. // Fails if err does not implement the interface.
  279. //
  280. // reflect.Type
  281. //
  282. // Fails if err does not implement the [reflect.Type].
  283. //
  284. // Deprecated: Use [ErrorIs]
  285. func ErrorType(err error, expected interface{}) Comparison {
  286. return func() Result {
  287. switch expectedType := expected.(type) {
  288. case func(error) bool:
  289. return cmpErrorTypeFunc(err, expectedType)
  290. case reflect.Type:
  291. if expectedType.Kind() == reflect.Interface {
  292. return cmpErrorTypeImplementsType(err, expectedType)
  293. }
  294. return cmpErrorTypeEqualType(err, expectedType)
  295. case nil:
  296. return ResultFailure("invalid type for expected: nil")
  297. }
  298. expectedType := reflect.TypeOf(expected)
  299. switch {
  300. case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
  301. return cmpErrorTypeEqualType(err, expectedType)
  302. case isPtrToInterface(expectedType):
  303. return cmpErrorTypeImplementsType(err, expectedType.Elem())
  304. }
  305. return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
  306. }
  307. }
  308. func cmpErrorTypeFunc(err error, f func(error) bool) Result {
  309. if f(err) {
  310. return ResultSuccess
  311. }
  312. actual := "nil"
  313. if err != nil {
  314. actual = fmt.Sprintf("%s (%T)", err, err)
  315. }
  316. return ResultFailureTemplate(`error is {{ .Data.actual }}
  317. {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
  318. map[string]interface{}{"actual": actual})
  319. }
  320. func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
  321. if err == nil {
  322. return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
  323. }
  324. errValue := reflect.ValueOf(err)
  325. if errValue.Type() == expectedType {
  326. return ResultSuccess
  327. }
  328. return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
  329. }
  330. func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
  331. if err == nil {
  332. return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
  333. }
  334. errValue := reflect.ValueOf(err)
  335. if errValue.Type().Implements(expectedType) {
  336. return ResultSuccess
  337. }
  338. return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
  339. }
  // isPtrToInterface reports whether typ is a pointer whose element is an
// interface type, such as *error.
func isPtrToInterface(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Interface
}
  // isPtrToStruct reports whether typ is a pointer whose element is a struct
// type, such as *struct{}.
func isPtrToStruct(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Struct
}
  // Dynamic types of the errors produced by errors.New and by fmt.Errorf with a
// %w verb. Only the types are kept; the sample error values are discarded.
// NOTE(review): presumably consulted by the notStdlibErrorType template helper
// used in ErrorIs (defined elsewhere) — confirm.
var (
	stdlibErrorNewType = reflect.TypeOf(errors.New("x"))
	stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", errors.New("x")))
)
  350. // ErrorIs succeeds if errors.Is(actual, expected) returns true. See
  351. // [errors.Is] for accepted argument values.
  352. func ErrorIs(actual error, expected error) Comparison {
  353. return func() Result {
  354. if errors.Is(actual, expected) {
  355. return ResultSuccess
  356. }
  357. // The type of stdlib errors is excluded because the type is not relevant
  358. // in those cases. The type is only important when it is a user defined
  359. // custom error type.
  360. return ResultFailureTemplate(`error is
  361. {{- if not .Data.a }} nil,{{ else }}
  362. {{- printf " \"%v\"" .Data.a }}
  363. {{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
  364. {{- end }} not {{ printf "\"%v\"" .Data.x }} (
  365. {{- with callArg 1 }}{{ formatNode . }}{{ end }}
  366. {{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
  367. map[string]interface{}{"a": actual, "x": expected})
  368. }
  369. }